This is an automated email from the ASF dual-hosted git repository.

hcr pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/mahout.git


The following commit(s) were added to refs/heads/main by this push:
     new cedcc7622 fix: misaligned vector loads for odd-length samples in batch amplitude and norm kernels (#1108)
cedcc7622 is described below

commit cedcc762286f8bc7c3ba7569a5deb4e5e15b8bba
Author: Vic Wen <[email protected]>
AuthorDate: Wed Mar 4 14:41:58 2026 +0800

    fix: misaligned vector loads for odd-length samples in batch amplitude and norm kernels (#1108)
---
 qdp/qdp-kernels/src/amplitude.cu          |  73 ++++++++----
 qdp/qdp-kernels/tests/amplitude_encode.rs | 185 +++++++++++++++++++++++++++++-
 2 files changed, 235 insertions(+), 23 deletions(-)

diff --git a/qdp/qdp-kernels/src/amplitude.cu b/qdp/qdp-kernels/src/amplitude.cu
index 67a863eaa..224e6bcd6 100644
--- a/qdp/qdp-kernels/src/amplitude.cu
+++ b/qdp/qdp-kernels/src/amplitude.cu
@@ -20,6 +20,7 @@
 #include <cuComplex.h>
 #include <vector_types.h>
 #include <math.h>
+#include <stdint.h>
 #include "kernel_config.h"
 
 __global__ void amplitude_encode_kernel(
@@ -231,10 +232,10 @@ int launch_amplitude_encode_f32(
 /// - state_batch: [sample0_state | sample1_state | ... | sampleN_state]
 ///
 /// Optimizations:
-/// 1. Vectorized double2 loads for 128-bit memory transactions
+/// 1. Vectorized double2 loads for 128-bit memory transactions when aligned
 /// 2. Grid-stride loop for arbitrary batch sizes
 /// 3. Coalesced memory access within warps
-/// 4. Minimized register pressure
+/// 4. Scalar fallback for misaligned sample bases and odd tails
 __global__ void amplitude_encode_batch_kernel(
     const double* __restrict__ input_batch,
     cuDoubleComplex* __restrict__ state_batch,
@@ -264,17 +265,21 @@ __global__ void amplitude_encode_batch_kernel(
         // Load inverse norm (cached by L1)
         const double inv_norm = inv_norms[sample_idx];
 
-        // Vectorized load: read 2 doubles as double2 for 128-bit transaction
         double v1, v2;
-        if (elem_offset + 1 < input_len) {
-            // Aligned vectorized load
-            const double2 vec_data = __ldg(reinterpret_cast<const 
double2*>(input_batch + input_base) + elem_pair);
+        const double* sample_input = input_batch + input_base;
+        const bool sample_input_aligned =
+            (reinterpret_cast<uintptr_t>(sample_input) & (alignof(double2) - 
1)) == 0;
+
+        if (sample_input_aligned && elem_offset + 1 < input_len) {
+            const double2 vec_data =
+                __ldg(reinterpret_cast<const double2*>(sample_input) + 
elem_pair);
             v1 = vec_data.x;
             v2 = vec_data.y;
         } else if (elem_offset < input_len) {
-            // Edge case: single element load
-            v1 = __ldg(input_batch + input_base + elem_offset);
-            v2 = 0.0;
+            v1 = __ldg(sample_input + elem_offset);
+            v2 = (elem_offset + 1 < input_len)
+                ? __ldg(sample_input + elem_offset + 1)
+                : 0.0;
         } else {
             // Padding region
             v1 = v2 = 0.0;
@@ -432,20 +437,33 @@ __global__ void l2_norm_batch_kernel(
 
     const size_t vec_idx = block_in_sample * blockDim.x + threadIdx.x;
     const size_t stride = blockDim.x * blocks_per_sample;
+    const double* sample_input = input_batch + base;
+    const bool sample_input_aligned =
+        (reinterpret_cast<uintptr_t>(sample_input) & (alignof(double2) - 1)) 
== 0;
 
     double local_sum = 0.0;
 
     size_t vec_offset = vec_idx;
     size_t offset = vec_offset * 2;
-    while (offset + 1 < sample_len) {
-        const double2 v = __ldg(reinterpret_cast<const double2*>(input_batch + 
base) + vec_offset);
-        local_sum += v.x * v.x + v.y * v.y;
-        vec_offset += stride;
-        offset = vec_offset * 2;
+    if (sample_input_aligned) {
+        while (offset + 1 < sample_len) {
+            const double2 v = __ldg(reinterpret_cast<const 
double2*>(sample_input) + vec_offset);
+            local_sum += v.x * v.x + v.y * v.y;
+            vec_offset += stride;
+            offset = vec_offset * 2;
+        }
+    } else {
+        while (offset + 1 < sample_len) {
+            const double v1 = __ldg(sample_input + offset);
+            const double v2 = __ldg(sample_input + offset + 1);
+            local_sum += v1 * v1 + v2 * v2;
+            vec_offset += stride;
+            offset = vec_offset * 2;
+        }
     }
 
     if (offset < sample_len) {
-        const double v = __ldg(input_batch + base + offset);
+        const double v = __ldg(sample_input + offset);
         local_sum += v * v;
     }
 
@@ -472,20 +490,33 @@ __global__ void l2_norm_batch_kernel_f32(
 
     const size_t vec_idx = block_in_sample * blockDim.x + threadIdx.x;
     const size_t stride = blockDim.x * blocks_per_sample;
+    const float* sample_input = input_batch + base;
+    const bool sample_input_aligned =
+        (reinterpret_cast<uintptr_t>(sample_input) & (alignof(float2) - 1)) == 
0;
 
     float local_sum = 0.0f;
 
     size_t vec_offset = vec_idx;
     size_t offset = vec_offset * 2;
-    while (offset + 1 < sample_len) {
-        const float2 v = __ldg(reinterpret_cast<const float2*>(input_batch + 
base) + vec_offset);
-        local_sum += v.x * v.x + v.y * v.y;
-        vec_offset += stride;
-        offset = vec_offset * 2;
+    if (sample_input_aligned) {
+        while (offset + 1 < sample_len) {
+            const float2 v = __ldg(reinterpret_cast<const 
float2*>(sample_input) + vec_offset);
+            local_sum += v.x * v.x + v.y * v.y;
+            vec_offset += stride;
+            offset = vec_offset * 2;
+        }
+    } else {
+        while (offset + 1 < sample_len) {
+            const float v1 = __ldg(sample_input + offset);
+            const float v2 = __ldg(sample_input + offset + 1);
+            local_sum += v1 * v1 + v2 * v2;
+            vec_offset += stride;
+            offset = vec_offset * 2;
+        }
     }
 
     if (offset < sample_len) {
-        const float v = __ldg(input_batch + base + offset);
+        const float v = __ldg(sample_input + offset);
         local_sum += v * v;
     }
 
diff --git a/qdp/qdp-kernels/tests/amplitude_encode.rs 
b/qdp/qdp-kernels/tests/amplitude_encode.rs
index 579a4e7ec..53f91505f 100644
--- a/qdp/qdp-kernels/tests/amplitude_encode.rs
+++ b/qdp/qdp-kernels/tests/amplitude_encode.rs
@@ -25,13 +25,50 @@
 use cudarc::driver::{CudaDevice, DevicePtr, DevicePtrMut};
 #[cfg(target_os = "linux")]
 use qdp_kernels::{
-    CuComplex, CuDoubleComplex, launch_amplitude_encode, 
launch_amplitude_encode_f32,
-    launch_l2_norm, launch_l2_norm_batch, launch_l2_norm_batch_f32, 
launch_l2_norm_f32,
+    CuComplex, CuDoubleComplex, launch_amplitude_encode, 
launch_amplitude_encode_batch,
+    launch_amplitude_encode_f32, launch_l2_norm, launch_l2_norm_batch, 
launch_l2_norm_batch_f32,
+    launch_l2_norm_f32,
 };
 
 const EPSILON: f64 = 1e-10;
 const EPSILON_F32: f32 = 1e-5;
 
+#[cfg(target_os = "linux")]
+fn assert_batch_state_matches_f64(
+    state_h: &[CuDoubleComplex],
+    input: &[f64],
+    num_samples: usize,
+    sample_len: usize,
+    state_len: usize,
+) {
+    for sample_idx in 0..num_samples {
+        let sample = &input[sample_idx * sample_len..(sample_idx + 1) * 
sample_len];
+        let norm = sample.iter().map(|v| v * v).sum::<f64>().sqrt();
+        for elem_idx in 0..state_len {
+            let expected = if elem_idx < sample_len {
+                sample[elem_idx] / norm
+            } else {
+                0.0
+            };
+            let actual = state_h[sample_idx * state_len + elem_idx];
+            assert!(
+                (actual.x - expected).abs() < EPSILON,
+                "sample {} element {} expected {}, got {}",
+                sample_idx,
+                elem_idx,
+                expected,
+                actual.x
+            );
+            assert!(
+                actual.y.abs() < EPSILON,
+                "sample {} element {} imaginary should be 0",
+                sample_idx,
+                elem_idx
+            );
+        }
+    }
+}
+
 #[test]
 #[cfg(target_os = "linux")]
 fn test_amplitude_encode_basic() {
@@ -574,6 +611,52 @@ fn test_amplitude_encode_small_input_large_state() {
     println!("PASS: Small input with large state padding works correctly");
 }
 
+#[test]
+#[cfg(target_os = "linux")]
+fn test_amplitude_encode_batch_odd_sample_length_handles_misaligned_samples() {
+    println!("Testing batch amplitude encoding with odd sample length 
(float64)...");
+
+    let device = match CudaDevice::new(0) {
+        Ok(d) => d,
+        Err(_) => {
+            println!("SKIP: No CUDA device available");
+            return;
+        }
+    };
+
+    let num_samples = 2usize;
+    let sample_len = 3usize;
+    let state_len = 4usize;
+    let input: Vec<f64> = vec![1.0, 2.0, 2.0, 2.0, 1.0, 2.0];
+    let inv_norms: Vec<f64> = input
+        .chunks(sample_len)
+        .map(|sample| 1.0 / sample.iter().map(|v| v * v).sum::<f64>().sqrt())
+        .collect();
+
+    let input_d = device.htod_sync_copy(input.as_slice()).unwrap();
+    let inv_norms_d = device.htod_sync_copy(inv_norms.as_slice()).unwrap();
+    let mut state_d = device
+        .alloc_zeros::<CuDoubleComplex>(num_samples * state_len)
+        .unwrap();
+
+    let result = unsafe {
+        launch_amplitude_encode_batch(
+            *input_d.device_ptr() as *const f64,
+            *state_d.device_ptr_mut() as *mut std::ffi::c_void,
+            *inv_norms_d.device_ptr() as *const f64,
+            num_samples,
+            sample_len,
+            state_len,
+            std::ptr::null_mut(),
+        )
+    };
+
+    assert_eq!(result, 0, "Batch kernel launch should succeed");
+
+    let state_h = device.dtoh_sync_copy(&state_d).unwrap();
+    assert_batch_state_matches_f64(&state_h, &input, num_samples, sample_len, 
state_len);
+}
+
 #[test]
 #[cfg(target_os = "linux")]
 fn test_l2_norm_single_kernel() {
@@ -671,6 +754,55 @@ fn test_l2_norm_batch_kernel_stream() {
     println!("PASS: Batched norm reduction on stream matches CPU");
 }
 
+#[test]
+#[cfg(target_os = "linux")]
+fn test_l2_norm_batch_kernel_odd_sample_len() {
+    println!("Testing batched L2 norm reduction with odd sample length 
(float64)...");
+
+    let device = match CudaDevice::new(0) {
+        Ok(d) => d,
+        Err(_) => {
+            println!("SKIP: No CUDA device available");
+            return;
+        }
+    };
+
+    let sample_len = 3usize;
+    let num_samples = 2usize;
+    let input: Vec<f64> = vec![1.0, 2.0, 2.0, 2.0, 1.0, 2.0];
+    let expected: Vec<f64> = input
+        .chunks(sample_len)
+        .map(|chunk| 1.0 / chunk.iter().map(|v| v * v).sum::<f64>().sqrt())
+        .collect();
+
+    let input_d = device.htod_sync_copy(input.as_slice()).unwrap();
+    let mut norms_d = device.alloc_zeros::<f64>(num_samples).unwrap();
+
+    let status = unsafe {
+        launch_l2_norm_batch(
+            *input_d.device_ptr() as *const f64,
+            num_samples,
+            sample_len,
+            *norms_d.device_ptr_mut() as *mut f64,
+            std::ptr::null_mut(),
+        )
+    };
+
+    assert_eq!(status, 0, "Batch norm kernel should succeed");
+    device.synchronize().unwrap();
+
+    let norms_h = device.dtoh_sync_copy(&norms_d).unwrap();
+    for (i, (got, expect)) in norms_h.iter().zip(expected.iter()).enumerate() {
+        assert!(
+            (got - expect).abs() < EPSILON,
+            "Sample {} inv norm mismatch: expected {}, got {}",
+            i,
+            expect,
+            got
+        );
+    }
+}
+
 #[test]
 #[cfg(target_os = "linux")]
 fn test_l2_norm_batch_kernel_zero_num_samples() {
@@ -840,6 +972,55 @@ fn test_l2_norm_batch_kernel_f32() {
     println!("PASS: Batched norm reduction (float32) matches CPU");
 }
 
+#[test]
+#[cfg(target_os = "linux")]
+fn test_l2_norm_batch_kernel_f32_odd_sample_len() {
+    println!("Testing batched L2 norm reduction with odd sample length 
(float32)...");
+
+    let device = match CudaDevice::new(0) {
+        Ok(d) => d,
+        Err(_) => {
+            println!("SKIP: No CUDA device available");
+            return;
+        }
+    };
+
+    let sample_len = 3usize;
+    let num_samples = 2usize;
+    let input: Vec<f32> = vec![1.0, 2.0, 2.0, 2.0, 1.0, 2.0];
+    let expected: Vec<f32> = input
+        .chunks(sample_len)
+        .map(|chunk| 1.0 / chunk.iter().map(|v| v * v).sum::<f32>().sqrt())
+        .collect();
+
+    let input_d = device.htod_sync_copy(input.as_slice()).unwrap();
+    let mut norms_d = device.alloc_zeros::<f32>(num_samples).unwrap();
+
+    let status = unsafe {
+        launch_l2_norm_batch_f32(
+            *input_d.device_ptr() as *const f32,
+            num_samples,
+            sample_len,
+            *norms_d.device_ptr_mut() as *mut f32,
+            std::ptr::null_mut(),
+        )
+    };
+
+    assert_eq!(status, 0, "Batch norm f32 kernel should succeed");
+    device.synchronize().unwrap();
+
+    let norms_h = device.dtoh_sync_copy(&norms_d).unwrap();
+    for (i, (got, expect)) in norms_h.iter().zip(expected.iter()).enumerate() {
+        assert!(
+            (got - expect).abs() < EPSILON_F32,
+            "Sample {} inv norm mismatch: expected {}, got {}",
+            i,
+            expect,
+            got
+        );
+    }
+}
+
 #[test]
 #[cfg(target_os = "linux")]
 fn test_l2_norm_batch_kernel_zero_num_samples_f32() {

Reply via email to