zanmato1984 commented on code in PR #45515:
URL: https://github.com/apache/arrow/pull/45515#discussion_r1953018134


##########
cpp/src/arrow/compute/key_map_internal_avx2.cc:
##########
@@ -385,33 +382,52 @@ int SwissTable::extract_group_ids_avx2(const int 
num_keys, const uint32_t* hashe
       _mm256_storeu_si256(reinterpret_cast<__m256i*>(out_group_ids) + i, 
group_id);
     }
   } else {
+    int num_groupid_bits = num_groupid_bits_from_log_blocks(log_blocks_);
+    int num_groupid_bytes = num_groupid_bits / 8;
+    uint32_t mask = num_groupid_bytes == 1   ? 0xFF
+                    : num_groupid_bytes == 2 ? 0xFFFF
+                                             : 0xFFFFFFFF;

Review Comment:
   Not particularly. This is just moving the original code.



##########
cpp/src/arrow/compute/key_map_internal.h:
##########
@@ -220,6 +257,12 @@ class ARROW_EXPORT SwissTable {
     return bits_stamp_;
   }
 
+  static uint32_t group_id_mask_from_num_groupid_bits(int64_t 
num_groupid_bits) {
+    return static_cast<uint32_t>((1ULL << num_groupid_bits) - 1);
+  }
+
+  static constexpr int bytes_status_in_block_ = 8;

Review Comment:
   Yes, it is supposed to be. However, I was also following the naming convention 
of several existing compile-time constants in this class. I would like to 
change them all in another PR, to keep this one focused solely on its purpose: 
overflow prevention.



##########
cpp/src/arrow/compute/key_map_internal.cc:
##########
@@ -94,27 +94,32 @@ inline void SwissTable::search_block(uint64_t block, int 
stamp, int start_slot,
   *out_slot = static_cast<int>(CountLeadingZeros(matches | block_high_bits) >> 
3);
 }
 
-template <typename T, bool use_selection>
+template <bool use_selection>
 void SwissTable::extract_group_ids_imp(const int num_keys, const uint16_t* 
selection,
                                        const uint32_t* hashes, const uint8_t* 
local_slots,
-                                       uint32_t* out_group_ids, int 
element_offset,
-                                       int element_multiplier) const {
-  const T* elements = reinterpret_cast<const T*>(blocks_->data()) + 
element_offset;
+                                       uint32_t* out_group_ids) const {
   if (log_blocks_ == 0) {
-    ARROW_DCHECK(sizeof(T) == sizeof(uint8_t));
     for (int i = 0; i < num_keys; ++i) {
       uint32_t id = use_selection ? selection[i] : i;
-      uint32_t group_id = blocks()[8 + local_slots[id]];
+      uint32_t group_id =
+          block_data(/*block_id=*/0,
+                     /*num_block_bytes=*/0)[bytes_status_in_block_ + 
local_slots[id]];
       out_group_ids[id] = group_id;
     }
   } else {
+    int num_groupid_bits = num_groupid_bits_from_log_blocks(log_blocks_);
+    int num_groupid_bytes = num_groupid_bits / 8;
+    uint32_t group_id_mask = 
group_id_mask_from_num_groupid_bits(num_groupid_bits);
+    int num_block_bytes = 
num_block_bytes_from_num_groupid_bits(num_groupid_bits);
+
     for (int i = 0; i < num_keys; ++i) {
       uint32_t id = use_selection ? selection[i] : i;
       uint32_t hash = hashes[id];
-      int64_t pos =
-          (hash >> (bits_hash_ - log_blocks_)) * element_multiplier + 
local_slots[id];
-      uint32_t group_id = static_cast<uint32_t>(elements[pos]);
-      ARROW_DCHECK(group_id < num_inserted_ || num_inserted_ == 0);
+      uint32_t block_id = block_id_from_hash(hash, log_blocks_);
+      uint32_t group_id = *reinterpret_cast<const uint32_t*>(
+          block_data(block_id, num_block_bytes) + local_slots[id] * 
num_groupid_bytes +
+          bytes_status_in_block_);
+      group_id &= group_id_mask;

Review Comment:
   > So we always issue a 32-bit load but then we optionally mask if the actual 
group id width is smaller? Don't we risk reading past `block_data` bounds here?
   
   There are always `padding_` (64) extra bytes at the end of the buffer, so the 32-bit load cannot read past the allocation.
   
   > (also, should we use an unaligned load? see the `SafeLoad` and 
`SafeLoadAs` utility functions)
   
   It seems so indeed, though I kept the original code's approach here rather than changing it.
   
   I'll update it later.



-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: [email protected]

For queries about this service, please contact Infrastructure at:
[email protected]

Reply via email to