crepererum commented on code in PR #6288:
URL: https://github.com/apache/arrow-rs/pull/6288#discussion_r1746791154


##########
arrow-buffer/src/util/bit_mask.rs:
##########
@@ -32,33 +31,126 @@ pub fn set_bits(
 ) -> usize {
     let mut null_count = 0;
 
-    let mut bits_to_align = offset_write % 8;
-    if bits_to_align > 0 {
-        bits_to_align = std::cmp::min(len, 8 - bits_to_align);
+    let mut acc = 0;
+    while len > acc {
+        let (n, l) = set_upto_64bits(
+            write_data,
+            data,
+            offset_write + acc,
+            offset_read + acc,
+            len - acc,
+        );
+        null_count += n;
+        acc += l;
     }
-    let mut write_byte_index = ceil(offset_write + bits_to_align, 8);
-
-    // Set full bytes provided by bit chunk iterator (which iterates in 64 bits at a time)
-    let chunks = BitChunks::new(data, offset_read + bits_to_align, len - bits_to_align);
-    chunks.iter().for_each(|chunk| {
-        null_count += chunk.count_zeros();
-        write_data[write_byte_index..write_byte_index + 8].copy_from_slice(&chunk.to_le_bytes());
-        write_byte_index += 8;
-    });
-
-    // Set individual bits both to align write_data to a byte offset and the remainder bits not covered by the bit chunk iterator
-    let remainder_offset = len - chunks.remainder_len();
-    (0..bits_to_align)
-        .chain(remainder_offset..len)
-        .for_each(|i| {
-            if get_bit(data, offset_read + i) {
-                set_bit(write_data, offset_write + i);
+
+    null_count
+}
+
+#[inline]
+fn set_upto_64bits(
+    write_data: &mut [u8],
+    data: &[u8],
+    offset_write: usize,
+    offset_read: usize,
+    len: usize,
+) -> (usize, usize) {
+    let read_byte = offset_read / 8;
+    let read_shift = offset_read % 8;
+    let write_byte = offset_write / 8;
+    let write_shift = offset_write % 8;
+
+    if len >= 64 {
+        // SAFETY: chunk gets masked when necessary, so it is safe
+        let chunk = unsafe { read_bytes_to_u64(data, read_byte, 8) };
+        if read_shift == 0 {
+            if write_shift == 0 {
+                let len = 64;
+                let null_count = chunk.count_zeros() as usize;
+                write_u64_bytes(write_data, write_byte, chunk);
+                (null_count, len)
             } else {
-                null_count += 1;
+                let len = 64 - write_shift;
+                let chunk = chunk << write_shift;
+                let null_count = len - chunk.count_ones() as usize;
+                or_write_u64_bytes(write_data, write_byte, chunk);
+                (null_count, len)
             }
-        });
+        } else if write_shift == 0 {
+            let len = 64 - 8; // 56 bits so that write_shift == 0 for the next iteration
+            let chunk = (chunk >> read_shift) & 0x00FFFFFFFFFFFFFF; // 56 bits mask
+            let null_count = len - chunk.count_ones() as usize;
+            write_u64_bytes(write_data, write_byte, chunk);
+            (null_count, len)
+        } else {
+            let len = 64 - std::cmp::max(read_shift, write_shift);
+            let chunk = (chunk >> read_shift) << write_shift;
+            let null_count = len - chunk.count_ones() as usize;
+            or_write_u64_bytes(write_data, write_byte, chunk);
+            (null_count, len)
+        }
+    } else if len == 1 {
+        let c = (unsafe { *data.as_ptr().add(read_byte) } >> read_shift) & 1;
+        let ptr = write_data.as_mut_ptr();
+        unsafe { *ptr.add(write_byte) |= c << write_shift };
+        ((c ^ 1) as usize, 1)
+    } else {
+        let len = std::cmp::min(len, 64 - std::cmp::max(read_shift, write_shift));
+        let bytes = ceil(len + read_shift, 8);
+        // SAFETY: chunk gets masked, so it is safe
+        let chunk = unsafe { read_bytes_to_u64(data, read_byte, bytes) };
+        let mask = u64::MAX >> (64 - len);
+        let chunk = (chunk >> read_shift) & mask;
+        let chunk = chunk << write_shift;
+        let null_count = len - chunk.count_ones() as usize;
+        let bytes = ceil(len + write_shift, 8);
+        let ptr = unsafe { write_data.as_mut_ptr().add(write_byte) };
+        for (i, c) in chunk.to_le_bytes().iter().enumerate().take(bytes) {
+            unsafe { *ptr.add(i) |= c };
+        }
+        (null_count, len)
+    }
+}
+
+/// # Safety
+/// The caller must ensure all arguments are within the valid range.
+/// The caller must be aware `8 - count` bytes in the returned value are uninitialized.
+#[inline]
+#[cfg(not(miri))]

Review Comment:
   For others reading this thread:
   
   The Miri team concluded that for the conversion from `MaybeUninit` to `u64` 
via `assume_init`, it is irrelevant whether the resulting `u64` is ever used. The 
important bit is whether the memory was initialized, and that is the invariant we 
violated. This "init" vs. "uninit" distinction is also made by LLVM (Rust's 
`u64` is translated to LLVM `noundef i64`), so this is not even a Miri-specific 
question. The linked Miri ticket provides some further discussion and a 
link to LLVM IR illustrating the point.



-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: [email protected]

For queries about this service, please contact Infrastructure at:
[email protected]

Reply via email to