michalursa commented on a change in pull request #10290: URL: https://github.com/apache/arrow/pull/10290#discussion_r688249565
########## File path: cpp/src/arrow/compute/exec/key_compare_avx2.cc ########## @@ -25,160 +25,545 @@ namespace compute { #if defined(ARROW_HAVE_AVX2) -uint32_t KeyCompare::CompareFixedLength_UpTo8B_avx2( - uint32_t num_rows, const uint32_t* left_to_right_map, uint8_t* match_bytevector, - uint32_t length, const uint8_t* rows_left, const uint8_t* rows_right) { - ARROW_DCHECK(length <= 8); - __m256i offset_left = _mm256_setr_epi64x(0, length, length * 2, length * 3); - __m256i offset_left_incr = _mm256_set1_epi64x(length * 4); - __m256i mask = _mm256_set1_epi64x(~0ULL >> (8 * (8 - length))); - - constexpr uint32_t unroll = 4; - for (uint32_t i = 0; i < num_rows / unroll; ++i) { - auto key_left = _mm256_i64gather_epi64( - reinterpret_cast<arrow::util::int64_for_gather_t*>(rows_left), offset_left, 1); - offset_left = _mm256_add_epi64(offset_left, offset_left_incr); - __m128i offset_right = - _mm_loadu_si128(reinterpret_cast<const __m128i*>(left_to_right_map) + i); - offset_right = _mm_mullo_epi32(offset_right, _mm_set1_epi32(length)); - - auto key_right = _mm256_i32gather_epi64( - reinterpret_cast<arrow::util::int64_for_gather_t*>(rows_right), offset_right, 1); - uint32_t cmp = _mm256_movemask_epi8(_mm256_cmpeq_epi64( - _mm256_and_si256(key_left, mask), _mm256_and_si256(key_right, mask))); - reinterpret_cast<uint32_t*>(match_bytevector)[i] &= cmp; +template <bool use_selection> +uint32_t KeyCompare::NullUpdateColumnToRowImp_avx2( + uint32_t id_col, uint32_t num_rows_to_compare, const uint16_t* sel_left_maybe_null, + const uint32_t* left_to_right_map, KeyEncoder::KeyEncoderContext* ctx, + const KeyEncoder::KeyColumnArray& col, const KeyEncoder::KeyRowArray& rows, + uint8_t* match_bytevector) { + if (!rows.has_any_nulls(ctx) && !col.data(0)) { + return num_rows_to_compare; } + if (!col.data(0)) { + // Remove rows from the result for which the column value is a null + const uint8_t* null_masks = rows.null_masks(); + uint32_t null_mask_num_bytes = rows.metadata().null_masks_bytes_per_row; - uint32_t num_rows_processed = num_rows - (num_rows % unroll); - return num_rows_processed; -} + uint32_t num_processed = 0; + constexpr uint32_t unroll = 8; + for (uint32_t i = 0; i < num_rows_to_compare / unroll; ++i) { + __m256i irow_right; + if (use_selection) { + __m256i irow_left = _mm256_cvtepu16_epi32( + _mm_loadu_si128(reinterpret_cast<const __m128i*>(sel_left_maybe_null) + i)); + irow_right = _mm256_i32gather_epi32((const int*)left_to_right_map, irow_left, 4); + } else { + irow_right = + _mm256_loadu_si256(reinterpret_cast<const __m256i*>(left_to_right_map) + i); + } + __m256i bitid = + _mm256_mullo_epi32(irow_right, _mm256_set1_epi32(null_mask_num_bytes * 8)); + bitid = _mm256_add_epi32(bitid, _mm256_set1_epi32(id_col)); + __m256i right = + _mm256_i32gather_epi32((const int*)null_masks, _mm256_srli_epi32(bitid, 3), 1); + right = _mm256_and_si256( + _mm256_set1_epi32(1), + _mm256_srlv_epi32(right, _mm256_and_si256(bitid, _mm256_set1_epi32(7)))); + __m256i cmp = _mm256_cmpeq_epi32(right, _mm256_setzero_si256()); + uint32_t result_lo = + _mm256_movemask_epi8(_mm256_cvtepi32_epi64(_mm256_castsi256_si128(cmp))); + uint32_t result_hi = + _mm256_movemask_epi8(_mm256_cvtepi32_epi64(_mm256_extracti128_si256(cmp, 1))); + reinterpret_cast<uint64_t*>(match_bytevector)[i] &= + result_lo | (static_cast<uint64_t>(result_hi) << 32); + } + num_processed = num_rows_to_compare / unroll * unroll; + return num_processed; + } else if (!rows.has_any_nulls(ctx)) { + // Remove rows from the result for which the column value on left side is null + const uint8_t* non_nulls = col.data(0); + ARROW_DCHECK(non_nulls); + uint32_t num_processed = 0; + constexpr uint32_t unroll = 8; + for (uint32_t i = 0; i < num_rows_to_compare / unroll; ++i) { + __m256i cmp; + if (use_selection) { + __m256i irow_left = _mm256_cvtepu16_epi32( + _mm_loadu_si128(reinterpret_cast<const __m128i*>(sel_left_maybe_null) + i)); + irow_left = _mm256_add_epi32(irow_left, _mm256_set1_epi32(col.bit_offset(0))); + __m256i left = _mm256_i32gather_epi32((const int*)non_nulls, + _mm256_srli_epi32(irow_left, 3), 1); + left = _mm256_and_si256( + _mm256_set1_epi32(1), + _mm256_srlv_epi32(left, _mm256_and_si256(irow_left, _mm256_set1_epi32(7)))); + cmp = _mm256_cmpeq_epi32(left, _mm256_set1_epi32(1)); + } else { + __m256i left = _mm256_cvtepu8_epi32(_mm_set1_epi8(static_cast<uint8_t>( + reinterpret_cast<const uint16_t*>(non_nulls + i)[0] >> col.bit_offset(0)))); + __m256i bits = _mm256_setr_epi32(1, 2, 4, 8, 16, 32, 64, 128); + cmp = _mm256_cmpeq_epi32(_mm256_and_si256(left, bits), bits); + } + uint32_t result_lo = + _mm256_movemask_epi8(_mm256_cvtepi32_epi64(_mm256_castsi256_si128(cmp))); + uint32_t result_hi = + _mm256_movemask_epi8(_mm256_cvtepi32_epi64(_mm256_extracti128_si256(cmp, 1))); + reinterpret_cast<uint64_t*>(match_bytevector)[i] &= + result_lo | (static_cast<uint64_t>(result_hi) << 32); + num_processed = num_rows_to_compare / unroll * unroll; + } + return num_processed; + } else { + const uint8_t* null_masks = rows.null_masks(); + uint32_t null_mask_num_bytes = rows.metadata().null_masks_bytes_per_row; + const uint8_t* non_nulls = col.data(0); + ARROW_DCHECK(non_nulls); -uint32_t KeyCompare::CompareFixedLength_UpTo16B_avx2( - uint32_t num_rows, const uint32_t* left_to_right_map, uint8_t* match_bytevector, - uint32_t length, const uint8_t* rows_left, const uint8_t* rows_right) { - ARROW_DCHECK(length <= 16); - - constexpr uint64_t kByteSequence0To7 = 0x0706050403020100ULL; - constexpr uint64_t kByteSequence8To15 = 0x0f0e0d0c0b0a0908ULL; - - __m256i mask = - _mm256_cmpgt_epi8(_mm256_set1_epi8(length), - _mm256_setr_epi64x(kByteSequence0To7, kByteSequence8To15, - kByteSequence0To7, kByteSequence8To15)); - const uint8_t* key_left_ptr = rows_left; - - constexpr uint32_t unroll = 2; - for (uint32_t i = 0; i < num_rows / unroll; ++i) { - auto key_left = _mm256_inserti128_si256( - _mm256_castsi128_si256( - _mm_loadu_si128(reinterpret_cast<const __m128i*>(key_left_ptr))), - _mm_loadu_si128(reinterpret_cast<const __m128i*>(key_left_ptr + length)), 1); - key_left_ptr += length * 2; - auto key_right = _mm256_inserti128_si256( - _mm256_castsi128_si256(_mm_loadu_si128(reinterpret_cast<const __m128i*>( - rows_right + length * left_to_right_map[2 * i]))), - _mm_loadu_si128(reinterpret_cast<const __m128i*>( - rows_right + length * left_to_right_map[2 * i + 1])), - 1); - __m256i cmp = _mm256_cmpeq_epi64(_mm256_and_si256(key_left, mask), - _mm256_and_si256(key_right, mask)); - cmp = _mm256_and_si256(cmp, _mm256_shuffle_epi32(cmp, 0xee)); // 0b11101110 - cmp = _mm256_permute4x64_epi64(cmp, 0x08); // 0b00001000 - reinterpret_cast<uint16_t*>(match_bytevector)[i] &= - (_mm256_movemask_epi8(cmp) & 0xffff); - } + uint32_t num_processed = 0; + constexpr uint32_t unroll = 8; + for (uint32_t i = 0; i < num_rows_to_compare / unroll; ++i) { + __m256i left_null; + __m256i irow_right; + if (use_selection) { + __m256i irow_left = _mm256_cvtepu16_epi32( + _mm_loadu_si128(reinterpret_cast<const __m128i*>(sel_left_maybe_null) + i)); + irow_right = _mm256_i32gather_epi32((const int*)left_to_right_map, irow_left, 4); + irow_left = _mm256_add_epi32(irow_left, _mm256_set1_epi32(col.bit_offset(0))); + __m256i left = _mm256_i32gather_epi32((const int*)non_nulls, + _mm256_srli_epi32(irow_left, 3), 1); + left = _mm256_and_si256( + _mm256_set1_epi32(1), + _mm256_srlv_epi32(left, _mm256_and_si256(irow_left, _mm256_set1_epi32(7)))); + left_null = _mm256_cmpeq_epi32(left, _mm256_setzero_si256()); + } else { + irow_right = + _mm256_loadu_si256(reinterpret_cast<const __m256i*>(left_to_right_map) + i); + __m256i left = _mm256_cvtepu8_epi32(_mm_set1_epi8(static_cast<uint8_t>( + reinterpret_cast<const uint16_t*>(non_nulls + i)[0] >> col.bit_offset(0)))); + __m256i bits = _mm256_setr_epi32(1, 2, 4, 8, 16, 32, 64, 128); + left_null = + _mm256_cmpeq_epi32(_mm256_and_si256(left, bits), _mm256_setzero_si256()); + } + __m256i bitid = + _mm256_mullo_epi32(irow_right, _mm256_set1_epi32(null_mask_num_bytes * 8)); + bitid = _mm256_add_epi32(bitid, _mm256_set1_epi32(id_col)); + __m256i right = + _mm256_i32gather_epi32((const int*)null_masks, _mm256_srli_epi32(bitid, 3), 1); + right = _mm256_and_si256( + _mm256_set1_epi32(1), + _mm256_srlv_epi32(right, _mm256_and_si256(bitid, _mm256_set1_epi32(7)))); + __m256i right_null = _mm256_cmpeq_epi32(right, _mm256_set1_epi32(1)); - uint32_t num_rows_processed = num_rows - (num_rows % unroll); - return num_rows_processed; -} + uint64_t left_null_64 = + static_cast<uint32_t>(_mm256_movemask_epi8( + _mm256_cvtepi32_epi64(_mm256_castsi256_si128(left_null)))) | + (static_cast<uint64_t>(static_cast<uint32_t>(_mm256_movemask_epi8( + _mm256_cvtepi32_epi64(_mm256_extracti128_si256(left_null, 1))))) + << 32); -uint32_t KeyCompare::CompareFixedLength_avx2(uint32_t num_rows, - const uint32_t* left_to_right_map, - uint8_t* match_bytevector, uint32_t length, - const uint8_t* rows_left, - const uint8_t* rows_right) { - ARROW_DCHECK(length > 0); + uint64_t right_null_64 = + static_cast<uint32_t>(_mm256_movemask_epi8( + _mm256_cvtepi32_epi64(_mm256_castsi256_si128(right_null)))) | + (static_cast<uint64_t>(static_cast<uint32_t>(_mm256_movemask_epi8( + _mm256_cvtepi32_epi64(_mm256_extracti128_si256(right_null, 1))))) + << 32); - constexpr uint64_t kByteSequence0To7 = 0x0706050403020100ULL; - constexpr uint64_t kByteSequence8To15 = 0x0f0e0d0c0b0a0908ULL; - constexpr uint64_t kByteSequence16To23 = 0x1716151413121110ULL; - constexpr uint64_t kByteSequence24To31 = 0x1f1e1d1c1b1a1918ULL; + reinterpret_cast<uint64_t*>(match_bytevector)[i] |= left_null_64 & right_null_64; + reinterpret_cast<uint64_t*>(match_bytevector)[i] &= ~(left_null_64 ^ right_null_64); + } + num_processed = num_rows_to_compare / unroll * unroll; + return num_processed; + } +} - // Non-zero length guarantees no underflow - int32_t num_loops_less_one = (static_cast<int32_t>(length) + 31) / 32 - 1; +template <bool use_selection, class COMPARE8_FN> +uint32_t KeyCompare::CompareBinaryColumnToRowHelper_avx2( + uint32_t offset_within_row, uint32_t num_rows_to_compare, + const uint16_t* sel_left_maybe_null, const uint32_t* left_to_right_map, + KeyEncoder::KeyEncoderContext* ctx, const KeyEncoder::KeyColumnArray& col, + const KeyEncoder::KeyRowArray& rows, uint8_t* match_bytevector, + COMPARE8_FN compare8_fn) { + bool is_fixed_length = rows.metadata().is_fixed_length; + if (is_fixed_length) { + uint32_t fixed_length = rows.metadata().fixed_length; + const uint8_t* rows_left = col.data(1); + const uint8_t* rows_right = rows.data(1); + constexpr uint32_t unroll = 8; + __m256i irow_left = _mm256_setr_epi32(0, 1, 2, 3, 4, 5, 6, 7); + for (uint32_t i = 0; i < num_rows_to_compare / unroll; ++i) { + if (use_selection) { + irow_left = _mm256_cvtepu16_epi32( + _mm_loadu_si128(reinterpret_cast<const __m128i*>(sel_left_maybe_null) + i)); + } + __m256i irow_right; + if (use_selection) { + irow_right = _mm256_i32gather_epi32((const int*)left_to_right_map, irow_left, 4); + } else { + irow_right = + _mm256_loadu_si256(reinterpret_cast<const __m256i*>(left_to_right_map) + i); + } - __m256i tail_mask = - _mm256_cmpgt_epi8(_mm256_set1_epi8(length - num_loops_less_one * 32), - _mm256_setr_epi64x(kByteSequence0To7, kByteSequence8To15, - kByteSequence16To23, kByteSequence24To31)); + __m256i offset_right = + _mm256_mullo_epi32(irow_right, _mm256_set1_epi32(fixed_length)); + offset_right = _mm256_add_epi32(offset_right, _mm256_set1_epi32(offset_within_row)); - for (uint32_t irow_left = 0; irow_left < num_rows; ++irow_left) { - uint32_t irow_right = left_to_right_map[irow_left]; - uint32_t begin_left = length * irow_left; - uint32_t begin_right = length * irow_right; - const __m256i* key_left_ptr = - reinterpret_cast<const __m256i*>(rows_left + begin_left); - const __m256i* key_right_ptr = - reinterpret_cast<const __m256i*>(rows_right + begin_right); - __m256i result_or = _mm256_setzero_si256(); - int32_t i; - // length cannot be zero - for (i = 0; i < num_loops_less_one; ++i) { - __m256i key_left = _mm256_loadu_si256(key_left_ptr + i); - __m256i key_right = _mm256_loadu_si256(key_right_ptr + i); - result_or = _mm256_or_si256(result_or, _mm256_xor_si256(key_left, key_right)); + reinterpret_cast<uint64_t*>(match_bytevector)[i] = + compare8_fn(rows_left, rows_right, i * unroll, irow_left, offset_right); + + if (!use_selection) { + irow_left = _mm256_add_epi32(irow_left, _mm256_set1_epi32(8)); + } } + return num_rows_to_compare - (num_rows_to_compare % unroll); + } else { + const uint8_t* rows_left = col.data(1); + const uint32_t* offsets_right = rows.offsets(); + const uint8_t* rows_right = rows.data(2); + constexpr uint32_t unroll = 8; + __m256i irow_left = _mm256_setr_epi32(0, 1, 2, 3, 4, 5, 6, 7); + for (uint32_t i = 0; i < num_rows_to_compare / unroll; ++i) { + if (use_selection) { + irow_left = _mm256_cvtepu16_epi32( + _mm_loadu_si128(reinterpret_cast<const __m128i*>(sel_left_maybe_null) + i)); + } + __m256i irow_right; + if (use_selection) { + irow_right = _mm256_i32gather_epi32((const int*)left_to_right_map, irow_left, 4); + } else { + irow_right = + _mm256_loadu_si256(reinterpret_cast<const __m256i*>(left_to_right_map) + i); + } + __m256i offset_right = + _mm256_i32gather_epi32((const int*)offsets_right, irow_right, 4); + offset_right = _mm256_add_epi32(offset_right, _mm256_set1_epi32(offset_within_row)); - __m256i key_left = _mm256_loadu_si256(key_left_ptr + i); - __m256i key_right = _mm256_loadu_si256(key_right_ptr + i); - result_or = _mm256_or_si256( - result_or, _mm256_and_si256(tail_mask, _mm256_xor_si256(key_left, key_right))); - int result = _mm256_testz_si256(result_or, result_or) * 0xff; - match_bytevector[irow_left] &= result; + reinterpret_cast<uint64_t*>(match_bytevector)[i] = + compare8_fn(rows_left, rows_right, i * unroll, irow_left, offset_right); + + if (!use_selection) { + irow_left = _mm256_add_epi32(irow_left, _mm256_set1_epi32(8)); + } + } + return num_rows_to_compare - (num_rows_to_compare % unroll); } +} - uint32_t num_rows_processed = num_rows; - return num_rows_processed; +template <bool use_selection> +uint32_t KeyCompare::CompareBinaryColumnToRowImp_avx2( + uint32_t offset_within_row, uint32_t num_rows_to_compare, + const uint16_t* sel_left_maybe_null, const uint32_t* left_to_right_map, + KeyEncoder::KeyEncoderContext* ctx, const KeyEncoder::KeyColumnArray& col, + const KeyEncoder::KeyRowArray& rows, uint8_t* match_bytevector) { + uint32_t col_width = col.metadata().fixed_length; + if (col_width == 0) { + int bit_offset = col.bit_offset(1); + return CompareBinaryColumnToRowHelper_avx2<use_selection>( + offset_within_row, num_rows_to_compare, sel_left_maybe_null, left_to_right_map, + ctx, col, rows, match_bytevector, + [bit_offset](const uint8_t* left_base, const uint8_t* right_base, + uint32_t irow_left_base, __m256i irow_left, __m256i offset_right) { + __m256i left; + if (use_selection) { + irow_left = _mm256_add_epi32(irow_left, _mm256_set1_epi32(bit_offset)); + left = _mm256_i32gather_epi32((const int*)left_base, + _mm256_srli_epi32(irow_left, 3), 1); + left = _mm256_and_si256( + _mm256_set1_epi32(1), + _mm256_srlv_epi32(left, + _mm256_and_si256(irow_left, _mm256_set1_epi32(7)))); + left = _mm256_mullo_epi32(left, _mm256_set1_epi32(0xff)); + } else { + __m256i bits = _mm256_setr_epi32(1, 2, 4, 8, 16, 32, 64, 128); + uint32_t start_bit_index = irow_left_base + bit_offset; + uint8_t left_bits_8 = + (reinterpret_cast<const uint16_t*>(left_base + start_bit_index / 8)[0] >> + (start_bit_index % 8)) & + 0xff; + left = _mm256_cmpeq_epi32( + _mm256_and_si256(bits, _mm256_set1_epi8(left_bits_8)), bits); + left = _mm256_and_si256(left, _mm256_set1_epi32(0xff)); + } + __m256i right = _mm256_i32gather_epi32((const int*)right_base, offset_right, 1); + right = _mm256_and_si256(right, _mm256_set1_epi32(0xff)); + __m256i cmp = _mm256_cmpeq_epi32(left, right); + uint32_t result_lo = + _mm256_movemask_epi8(_mm256_cvtepi32_epi64(_mm256_castsi256_si128(cmp))); + uint32_t result_hi = _mm256_movemask_epi8( + _mm256_cvtepi32_epi64(_mm256_extracti128_si256(cmp, 1))); + return result_lo | (static_cast<uint64_t>(result_hi) << 32); + }); + } else if (col_width == 1) { Review comment: I refactored that part a bit. -- This is an automated message from the Apache Git Service. To respond to the message, please log on to GitHub and use the URL above to go to the specific comment. To unsubscribe, e-mail: github-unsubscr...@arrow.apache.org For queries about this service, please contact Infrastructure at: us...@infra.apache.org