This is an automated email from the ASF dual-hosted git repository.

guanmingchiu pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/mahout.git

commit d67cbe902d9391804fe216bef26b520cc6764d8d
Author: Ping <[email protected]>
AuthorDate: Thu Dec 11 20:45:08 2025 +0800

    [QDP] GPU/Cuda kernel Optimizations (#706)
    
    * GPU/Cuda kernel Optimizations
    
    Signed-off-by: 400Ping <[email protected]>
    
    * fix
    
    Signed-off-by: 400Ping <[email protected]>
    
    ---------
    
    Signed-off-by: 400Ping <[email protected]>
---
 qdp/qdp-core/src/gpu/encodings/amplitude.rs | 124 ++++++++++++---
 qdp/qdp-kernels/src/amplitude.cu            | 233 +++++++++++++++++++++++++++-
 qdp/qdp-kernels/src/lib.rs                  |  52 ++++++-
 qdp/qdp-kernels/tests/amplitude_encode.rs   | 123 +++++++++++++--
 4 files changed, 498 insertions(+), 34 deletions(-)
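
The patch moves L2 norm computation onto the GPU (warp-shuffle reduction kernels that write the reciprocal norm directly) and switches the encode kernels from taking `norm` to `inv_norm`. As a rough reference only, here is a minimal CPU-side sketch of the quantity the new kernels produce; this is plain Rust, not part of the patch, and the function name `inv_l2_norm` is illustrative:

    fn inv_l2_norm(data: &[f64]) -> Option<f64> {
        // Accumulate the sum of squares, then take 1/sqrt, mirroring what
        // l2_norm_kernel + finalize_inv_norm_kernel compute on the device.
        let sum_sq: f64 = data.iter().map(|v| v * v).sum();
        if sum_sq > 0.0 && sum_sq.is_finite() {
            Some(1.0 / sum_sq.sqrt())
        } else {
            None // zero or non-finite norm; the device kernels write 0.0 here
        }
    }

    fn main() {
        // [3.0, 4.0] has L2 norm 5.0, so the inverse norm is 0.2
        assert_eq!(inv_l2_norm(&[3.0, 4.0]), Some(0.2));
    }

A zero or non-finite result is rejected as invalid input on the host side, which matches the validation added in amplitude.rs below.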

diff --git a/qdp/qdp-core/src/gpu/encodings/amplitude.rs b/qdp/qdp-core/src/gpu/encodings/amplitude.rs
index ff1490c48..844f6d1e8 100644
--- a/qdp/qdp-core/src/gpu/encodings/amplitude.rs
+++ b/qdp/qdp-core/src/gpu/encodings/amplitude.rs
@@ -27,9 +27,14 @@ use super::QuantumEncoder;
 #[cfg(target_os = "linux")]
 use std::ffi::c_void;
 #[cfg(target_os = "linux")]
-use cudarc::driver::DevicePtr;
+use cudarc::driver::{DevicePtr, DevicePtrMut};
 #[cfg(target_os = "linux")]
-use qdp_kernels::{launch_amplitude_encode, launch_amplitude_encode_batch};
+use qdp_kernels::{
+    launch_amplitude_encode,
+    launch_amplitude_encode_batch,
+    launch_l2_norm,
+    launch_l2_norm_batch,
+};
 #[cfg(target_os = "linux")]
 use crate::gpu::memory::{ensure_device_memory_available, map_allocation_error};
 
@@ -50,7 +55,6 @@ impl QuantumEncoder for AmplitudeEncoder {
     ) -> Result<GpuStateVector> {
         // Validate qubits (max 30 = 16GB GPU memory)
         Preprocessor::validate_input(host_data, num_qubits)?;
-        let norm = Preprocessor::calculate_l2_norm(host_data)?;
         let state_len = 1 << num_qubits;
 
         #[cfg(target_os = "linux")]
@@ -65,6 +69,7 @@ impl QuantumEncoder for AmplitudeEncoder {
             // For small data (< 1MB), use synchronous path to avoid stream overhead
             // For large data, use dual-stream async pipeline for maximum throughput
             const ASYNC_THRESHOLD: usize = 1024 * 1024 / std::mem::size_of::<f64>(); // 1MB threshold
+            const GPU_NORM_THRESHOLD: usize = 4096; // heuristic: amortize kernel launch
 
             if host_data.len() < ASYNC_THRESHOLD {
                 // Synchronous path for small data (avoids stream overhead)
@@ -82,6 +87,18 @@ impl QuantumEncoder for AmplitudeEncoder {
                         ))?
                 };
 
+                // GPU-accelerated norm for medium+ inputs, CPU fallback for tiny payloads
+                let inv_norm = if host_data.len() >= GPU_NORM_THRESHOLD {
+                    Self::calculate_inv_norm_gpu(
+                        _device,
+                        *input_slice.device_ptr() as *const f64,
+                        host_data.len(),
+                    )?
+                } else {
+                    let norm = Preprocessor::calculate_l2_norm(host_data)?;
+                    1.0 / norm
+                };
+
                 let ret = {
                     crate::profile_scope!("GPU::KernelLaunch");
                     unsafe {
@@ -90,7 +107,7 @@ impl QuantumEncoder for AmplitudeEncoder {
                             state_vector.ptr() as *mut c_void,
                             host_data.len(),
                             state_len,
-                            norm,
+                            inv_norm,
                             std::ptr::null_mut(), // default stream
                         )
                     }
@@ -121,7 +138,9 @@ impl QuantumEncoder for AmplitudeEncoder {
                 }
             } else {
                 // Async Pipeline path for large data
-                Self::encode_async_pipeline(_device, host_data, num_qubits, state_len, norm, &state_vector)?;
+                let norm = Preprocessor::calculate_l2_norm(host_data)?;
+                let inv_norm = 1.0 / norm;
+                Self::encode_async_pipeline(_device, host_data, num_qubits, state_len, inv_norm, &state_vector)?;
             }
 
             Ok(state_vector)
@@ -150,12 +169,6 @@ impl QuantumEncoder for AmplitudeEncoder {
 
         let state_len = 1 << num_qubits;
 
-        // Calculate L2 norms using shared preprocessor (parallelized)
-        let norms = Preprocessor::calculate_batch_l2_norms(batch_data, num_samples, sample_size)?;
-
-        // Convert to inverse norms
-        let inv_norms: Vec<f64> = norms.iter().map(|n| 1.0 / n).collect();
-
         // Allocate single large GPU buffer for all states
         let batch_state_vector = {
             crate::profile_scope!("GPU::AllocBatch");
@@ -171,15 +184,46 @@ impl QuantumEncoder for AmplitudeEncoder {
                 ))?
         };
 
-        // Upload inverse norms to GPU
+        // Compute inverse norms on GPU using warp-reduced kernel
         let inv_norms_gpu = {
-            crate::profile_scope!("GPU::H2D_Norms");
-            device.htod_sync_copy(&inv_norms)
+            crate::profile_scope!("GPU::BatchNormKernel");
+            let mut buffer = device.alloc_zeros::<f64>(num_samples)
                 .map_err(|e| MahoutError::MemoryAllocation(
-                    format!("Failed to upload norms: {:?}", e)
-                ))?
+                    format!("Failed to allocate norm buffer: {:?}", e)
+                ))?;
+
+            let ret = unsafe {
+                launch_l2_norm_batch(
+                    *input_batch_gpu.device_ptr() as *const f64,
+                    num_samples,
+                    sample_size,
+                    *buffer.device_ptr_mut() as *mut f64,
+                    std::ptr::null_mut(), // default stream
+                )
+            };
+
+            if ret != 0 {
+                return Err(MahoutError::KernelLaunch(
+                    format!("Norm reduction kernel failed: {} ({})", ret, cuda_error_to_string(ret))
+                ));
+            }
+
+            buffer
         };
 
+        // Validate norms on host to catch zero or NaN samples early
+        {
+            crate::profile_scope!("GPU::NormValidation");
+            let host_inv_norms = device.dtoh_sync_copy(&inv_norms_gpu)
+                .map_err(|e| MahoutError::Cuda(format!("Failed to copy norms to host: {:?}", e)))?;
+
+            if host_inv_norms.iter().any(|v| !v.is_finite() || *v == 0.0) {
+                return Err(MahoutError::InvalidInput(
+                    "One or more samples have zero or invalid norm".to_string()
+                ));
+            }
+        }
+
         // Launch batch kernel
         {
             crate::profile_scope!("GPU::BatchKernelLaunch");
@@ -236,7 +280,7 @@ impl AmplitudeEncoder {
         host_data: &[f64],
         _num_qubits: usize,
         state_len: usize,
-        norm: f64,
+        inv_norm: f64,
         state_vector: &GpuStateVector,
     ) -> Result<()> {
         // Use generic pipeline infrastructure
@@ -257,7 +301,7 @@ impl AmplitudeEncoder {
                     state_ptr_offset,
                     chunk_len,
                     state_len,
-                    norm,
+                    inv_norm,
                     stream.stream as *mut c_void,
                 )
             };
@@ -333,6 +377,50 @@ impl AmplitudeEncoder {
     }
 }
 
+impl AmplitudeEncoder {
+    /// Compute inverse L2 norm on GPU using the reduction kernel.
+    #[cfg(target_os = "linux")]
+    fn calculate_inv_norm_gpu(
+        device: &Arc<CudaDevice>,
+        input_ptr: *const f64,
+        len: usize,
+    ) -> Result<f64> {
+        crate::profile_scope!("GPU::NormSingle");
+
+        let mut norm_buffer = device.alloc_zeros::<f64>(1)
+            .map_err(|e| MahoutError::MemoryAllocation(
+                format!("Failed to allocate norm buffer: {:?}", e)
+            ))?;
+
+        let ret = unsafe {
+            launch_l2_norm(
+                input_ptr,
+                len,
+                *norm_buffer.device_ptr_mut() as *mut f64,
+                std::ptr::null_mut(), // default stream
+            )
+        };
+
+        if ret != 0 {
+            return Err(MahoutError::KernelLaunch(
+                format!("Norm kernel failed: {} ({})", ret, cuda_error_to_string(ret))
+            ));
+        }
+
+        let inv_norm_host = device.dtoh_sync_copy(&norm_buffer)
+            .map_err(|e| MahoutError::Cuda(format!("Failed to copy norm to host: {:?}", e)))?;
+
+        let inv_norm = inv_norm_host.get(0).copied().unwrap_or(0.0);
+        if inv_norm == 0.0 || !inv_norm.is_finite() {
+            return Err(MahoutError::InvalidInput(
+                "Input data has zero norm".to_string()
+            ));
+        }
+
+        Ok(inv_norm)
+    }
+}
+
 /// Convert CUDA error code to human-readable string
 #[cfg(target_os = "linux")]
 fn cuda_error_to_string(code: i32) -> &'static str {
diff --git a/qdp/qdp-kernels/src/amplitude.cu b/qdp/qdp-kernels/src/amplitude.cu
index 4dd80fb0b..0e679e2b9 100644
--- a/qdp/qdp-kernels/src/amplitude.cu
+++ b/qdp/qdp-kernels/src/amplitude.cu
@@ -19,6 +19,7 @@
 #include <cuda_runtime.h>
 #include <cuComplex.h>
 #include <vector_types.h>
+#include <math.h>
 
 __global__ void amplitude_encode_kernel(
     const double* __restrict__ input,
@@ -65,6 +66,34 @@ __global__ void amplitude_encode_kernel(
     }
 }
 
+// Warp-level reduction for sum using shuffle instructions
+__device__ __forceinline__ double warp_reduce_sum(double val) {
+    for (int offset = warpSize / 2; offset > 0; offset >>= 1) {
+        val += __shfl_down_sync(0xffffffff, val, offset);
+    }
+    return val;
+}
+
+// Block-level reduction built on top of warp reduction
+__device__ __forceinline__ double block_reduce_sum(double val) {
+    __shared__ double shared[32]; // supports up to 1024 threads (32 warps)
+    int lane = threadIdx.x & (warpSize - 1);
+    int warp_id = threadIdx.x >> 5;
+
+    val = warp_reduce_sum(val);
+    if (lane == 0) {
+        shared[warp_id] = val;
+    }
+    __syncthreads();
+
+    // Only first warp participates in final reduction
+    val = (threadIdx.x < (blockDim.x + warpSize - 1) / warpSize) ? shared[lane] : 0.0;
+    if (warp_id == 0) {
+        val = warp_reduce_sum(val);
+    }
+    return val;
+}
+
 extern "C" {
 
 /// Launch amplitude encoding kernel
@@ -74,7 +103,7 @@ extern "C" {
 /// * state_d - Device pointer to output state vector
 /// * input_len - Number of input elements
 /// * state_len - Target state vector size (2^num_qubits)
-/// * norm - L2 norm computed by host
+/// * inv_norm - Reciprocal L2 norm (1 / ||input||)
 /// * stream - CUDA stream for async execution (nullptr = default stream)
 ///
 /// # Returns
@@ -84,15 +113,13 @@ int launch_amplitude_encode(
     void* state_d,
     size_t input_len,
     size_t state_len,
-    double norm,
+    double inv_norm,
     cudaStream_t stream
 ) {
-    if (norm <= 0.0) {
+    if (inv_norm <= 0.0 || !isfinite(inv_norm)) {
         return cudaErrorInvalidValue;
     }
 
-    double inv_norm = 1.0 / norm;
-
     cuDoubleComplex* state_complex_d = static_cast<cuDoubleComplex*>(state_d);
 
     const int blockSize = 256;
@@ -231,6 +258,202 @@ int launch_amplitude_encode_batch(
     return (int)cudaGetLastError();
 }
 
+/// Kernel: accumulate L2 norm using coalesced vectorized loads.
+/// Each block atomically adds its partial sum to the output accumulator.
+__global__ void l2_norm_kernel(
+    const double* __restrict__ input,
+    size_t input_len,
+    double* __restrict__ out_accum
+) {
+    // Vectorized double2 loads for bandwidth and coalescing
+    const size_t vec_idx = blockIdx.x * blockDim.x + threadIdx.x;
+    const size_t stride = gridDim.x * blockDim.x;
+
+    double local_sum = 0.0;
+
+    // Process two elements per iteration via double2
+    size_t vec_offset = vec_idx;
+    size_t offset = vec_offset * 2;
+    while (offset + 1 < input_len) {
+        const double2 v = __ldg(reinterpret_cast<const double2*>(input) + vec_offset);
+        local_sum += v.x * v.x + v.y * v.y;
+        vec_offset += stride;
+        offset = vec_offset * 2;
+    }
+
+    // Handle tail element if input_len is odd
+    if (offset < input_len) {
+        const double v = __ldg(input + offset);
+        local_sum += v * v;
+    }
+
+    const double block_sum = block_reduce_sum(local_sum);
+    if (threadIdx.x == 0) {
+        atomicAdd(out_accum, block_sum);
+    }
+}
+
+/// Kernel: accumulate L2 norms for a batch.
+/// Grid is organized as (blocks_per_sample * num_samples) blocks.
+__global__ void l2_norm_batch_kernel(
+    const double* __restrict__ input_batch,
+    size_t num_samples,
+    size_t sample_len,
+    size_t blocks_per_sample,
+    double* __restrict__ out_norms
+) {
+    const size_t sample_idx = blockIdx.x / blocks_per_sample;
+    if (sample_idx >= num_samples) return;
+
+    const size_t block_in_sample = blockIdx.x % blocks_per_sample;
+    const size_t base = sample_idx * sample_len;
+
+    const size_t vec_idx = block_in_sample * blockDim.x + threadIdx.x;
+    const size_t stride = blockDim.x * blocks_per_sample;
+
+    double local_sum = 0.0;
+
+    size_t vec_offset = vec_idx;
+    size_t offset = vec_offset * 2;
+    while (offset + 1 < sample_len) {
+        const double2 v = __ldg(reinterpret_cast<const double2*>(input_batch + base) + vec_offset);
+        local_sum += v.x * v.x + v.y * v.y;
+        vec_offset += stride;
+        offset = vec_offset * 2;
+    }
+
+    if (offset < sample_len) {
+        const double v = __ldg(input_batch + base + offset);
+        local_sum += v * v;
+    }
+
+    const double block_sum = block_reduce_sum(local_sum);
+    if (threadIdx.x == 0) {
+        atomicAdd(out_norms + sample_idx, block_sum);
+    }
+}
+
+/// Kernel: converts accumulated sum-of-squares into inverse norms.
+__global__ void finalize_inv_norm_kernel(
+    double* __restrict__ norms,
+    size_t count
+) {
+    const size_t idx = blockIdx.x * blockDim.x + threadIdx.x;
+    if (idx >= count) return;
+
+    double sum = norms[idx];
+    // Guard against zero or NaN to avoid inf propagation
+    if (sum <= 0.0 || !isfinite(sum)) {
+        norms[idx] = 0.0;
+    } else {
+        norms[idx] = rsqrt(sum);
+    }
+}
+
+/// Launch L2 norm reduction for a single vector.
+/// Writes the inverse norm (1 / ||x||) into `inv_norm_out_d`.
+int launch_l2_norm(
+    const double* input_d,
+    size_t input_len,
+    double* inv_norm_out_d,
+    cudaStream_t stream
+) {
+    if (input_len == 0) {
+        return cudaErrorInvalidValue;
+    }
+
+    cudaError_t memset_status = cudaMemsetAsync(
+        inv_norm_out_d,
+        0,
+        sizeof(double),
+        stream
+    );
+    if (memset_status != cudaSuccess) {
+        return memset_status;
+    }
+
+    const int blockSize = 256;
+    const size_t elements_per_block = blockSize * 2; // double2 per thread
+    size_t gridSize = (input_len + elements_per_block - 1) / elements_per_block;
+    gridSize = (gridSize == 0) ? 1 : gridSize;
+    const size_t maxBlocks = 4096;
+    if (gridSize > maxBlocks) gridSize = maxBlocks;
+
+    l2_norm_kernel<<<gridSize, blockSize, 0, stream>>>(
+        input_d,
+        input_len,
+        inv_norm_out_d
+    );
+
+    // Finalize: convert accumulated sum to inverse norm
+    finalize_inv_norm_kernel<<<1, 32, 0, stream>>>(
+        inv_norm_out_d,
+        1
+    );
+
+    return (int)cudaGetLastError();
+}
+
+/// Launch L2 norm reduction for a batch of vectors.
+/// Writes inverse norms for each sample into `inv_norms_out_d`.
+int launch_l2_norm_batch(
+    const double* input_batch_d,
+    size_t num_samples,
+    size_t sample_len,
+    double* inv_norms_out_d,
+    cudaStream_t stream
+) {
+    if (num_samples == 0 || sample_len == 0) {
+        return cudaErrorInvalidValue;
+    }
+
+    cudaError_t memset_status = cudaMemsetAsync(
+        inv_norms_out_d,
+        0,
+        num_samples * sizeof(double),
+        stream
+    );
+    if (memset_status != cudaSuccess) {
+        return memset_status;
+    }
+
+    const int blockSize = 256;
+    const size_t elements_per_block = blockSize * 2; // double2 per thread
+    size_t blocks_per_sample = (sample_len + elements_per_block - 1) / elements_per_block;
+    const size_t max_blocks_per_sample = 32;
+    if (blocks_per_sample == 0) blocks_per_sample = 1;
+    if (blocks_per_sample > max_blocks_per_sample) {
+        blocks_per_sample = max_blocks_per_sample;
+    }
+
+    size_t gridSize = num_samples * blocks_per_sample;
+    const size_t max_grid = 65535; // CUDA grid dimension limit for 1D launch
+    if (gridSize > max_grid) {
+        blocks_per_sample = max_grid / num_samples;
+        if (blocks_per_sample == 0) {
+            blocks_per_sample = 1;
+        }
+        gridSize = num_samples * blocks_per_sample;
+    }
+
+    l2_norm_batch_kernel<<<gridSize, blockSize, 0, stream>>>(
+        input_batch_d,
+        num_samples,
+        sample_len,
+        blocks_per_sample,
+        inv_norms_out_d
+    );
+
+    const int finalizeBlock = 256;
+    const int finalizeGrid = (num_samples + finalizeBlock - 1) / finalizeBlock;
+    finalize_inv_norm_kernel<<<finalizeGrid, finalizeBlock, 0, stream>>>(
+        inv_norms_out_d,
+        num_samples
+    );
+
+    return (int)cudaGetLastError();
+}
+
 // TODO: Future encoding methods:
 // - launch_angle_encode (angle encoding)
 // - launch_basis_encode (basis encoding)
diff --git a/qdp/qdp-kernels/src/lib.rs b/qdp/qdp-kernels/src/lib.rs
index c1bb07544..f2ce2ad29 100644
--- a/qdp/qdp-kernels/src/lib.rs
+++ b/qdp/qdp-kernels/src/lib.rs
@@ -49,7 +49,7 @@ unsafe extern "C" {
         state_d: *mut c_void,
         input_len: usize,
         state_len: usize,
-        norm: f64,
+        inv_norm: f64,
         stream: *mut c_void,
     ) -> i32;
 
@@ -68,6 +68,31 @@ unsafe extern "C" {
         stream: *mut c_void,
     ) -> i32;
 
+    /// Launch L2 norm reduction (returns inverse norm)
+    /// Returns CUDA error code (0 = success)
+    ///
+    /// # Safety
+    /// Pointers must reference valid device memory on the provided stream.
+    pub fn launch_l2_norm(
+        input_d: *const f64,
+        input_len: usize,
+        inv_norm_out_d: *mut f64,
+        stream: *mut c_void,
+    ) -> i32;
+
+    /// Launch batched L2 norm reduction (returns inverse norms per sample)
+    /// Returns CUDA error code (0 = success)
+    ///
+    /// # Safety
+    /// Pointers must reference valid device memory on the provided stream.
+    pub fn launch_l2_norm_batch(
+        input_batch_d: *const f64,
+        num_samples: usize,
+        sample_len: usize,
+        inv_norms_out_d: *mut f64,
+        stream: *mut c_void,
+    ) -> i32;
+
     // TODO: launch_angle_encode, launch_basis_encode
 }
 
@@ -79,8 +104,31 @@ pub extern "C" fn launch_amplitude_encode(
     _state_d: *mut c_void,
     _input_len: usize,
     _state_len: usize,
-    _norm: f64,
+    _inv_norm: f64,
     _stream: *mut c_void,
 ) -> i32 {
     999 // Error: CUDA unavailable
 }
+
+#[cfg(not(target_os = "linux"))]
+#[unsafe(no_mangle)]
+pub extern "C" fn launch_l2_norm(
+    _input_d: *const f64,
+    _input_len: usize,
+    _inv_norm_out_d: *mut f64,
+    _stream: *mut c_void,
+) -> i32 {
+    999
+}
+
+#[cfg(not(target_os = "linux"))]
+#[unsafe(no_mangle)]
+pub extern "C" fn launch_l2_norm_batch(
+    _input_batch_d: *const f64,
+    _num_samples: usize,
+    _sample_len: usize,
+    _inv_norms_out_d: *mut f64,
+    _stream: *mut c_void,
+) -> i32 {
+    999
+}
diff --git a/qdp/qdp-kernels/tests/amplitude_encode.rs b/qdp/qdp-kernels/tests/amplitude_encode.rs
index 2ac125c3e..e290d550c 100644
--- a/qdp/qdp-kernels/tests/amplitude_encode.rs
+++ b/qdp/qdp-kernels/tests/amplitude_encode.rs
@@ -19,7 +19,7 @@
 #[cfg(target_os = "linux")]
 use cudarc::driver::{CudaDevice, DevicePtr, DevicePtrMut};
 #[cfg(target_os = "linux")]
-use qdp_kernels::{CuDoubleComplex, launch_amplitude_encode};
+use qdp_kernels::{CuDoubleComplex, launch_amplitude_encode, launch_l2_norm, launch_l2_norm_batch};
 
 const EPSILON: f64 = 1e-10;
 
@@ -40,6 +40,7 @@ fn test_amplitude_encode_basic() {
     // Test input: [3.0, 4.0] -> normalized to [0.6, 0.8]
     let input = vec![3.0, 4.0];
     let norm = (3.0_f64.powi(2) + 4.0_f64.powi(2)).sqrt(); // 5.0
+    let inv_norm = 1.0 / norm;
     let state_len = 4; // 2 qubits
 
     // Allocate device memory
@@ -53,7 +54,7 @@ fn test_amplitude_encode_basic() {
             *state_d.device_ptr_mut() as *mut std::ffi::c_void,
             input.len(),
             state_len,
-            norm,
+            inv_norm,
             std::ptr::null_mut(),
         )
     };
@@ -109,6 +110,7 @@ fn test_amplitude_encode_power_of_two() {
     // Test with 8 input values (fills 3-qubit state)
     let input: Vec<f64> = (1..=8).map(|x| x as f64).collect();
     let norm: f64 = input.iter().map(|x| x * x).sum::<f64>().sqrt();
+    let inv_norm = 1.0 / norm;
     let state_len = 8;
 
     let input_d = device.htod_copy(input.clone()).unwrap();
@@ -120,7 +122,7 @@ fn test_amplitude_encode_power_of_two() {
             *state_d.device_ptr_mut() as *mut std::ffi::c_void,
             input.len(),
             state_len,
-            norm,
+            inv_norm,
             std::ptr::null_mut(),
         )
     };
@@ -168,6 +170,7 @@ fn test_amplitude_encode_odd_input_length() {
     // Test with 3 input values, state size 4
     let input = vec![1.0, 2.0, 2.0];
     let norm = (1.0_f64 + 4.0 + 4.0).sqrt(); // 3.0
+    let inv_norm = 1.0 / norm;
     let state_len = 4;
 
     let input_d = device.htod_copy(input.clone()).unwrap();
@@ -179,7 +182,7 @@ fn test_amplitude_encode_odd_input_length() {
             *state_d.device_ptr_mut() as *mut std::ffi::c_void,
             input.len(),
             state_len,
-            norm,
+            inv_norm,
             std::ptr::null_mut(),
         )
     };
@@ -217,6 +220,7 @@ fn test_amplitude_encode_large_state() {
     let input_len = 1024;
     let input: Vec<f64> = (0..input_len).map(|i| (i + 1) as f64).collect();
     let norm: f64 = input.iter().map(|x| x * x).sum::<f64>().sqrt();
+    let inv_norm = 1.0 / norm;
     let state_len = 1024;
 
     let input_d = device.htod_copy(input.clone()).unwrap();
@@ -228,7 +232,7 @@ fn test_amplitude_encode_large_state() {
             *state_d.device_ptr_mut() as *mut std::ffi::c_void,
             input.len(),
             state_len,
-            norm,
+            inv_norm,
             std::ptr::null_mut(),
         )
     };
@@ -272,6 +276,7 @@ fn test_amplitude_encode_zero_norm_error() {
 
     let input = vec![0.0, 0.0, 0.0];
     let norm = 0.0; // Invalid!
+    let inv_norm = if norm == 0.0 { 0.0 } else { 1.0 / norm };
     let state_len = 4;
 
     let input_d = device.htod_copy(input).unwrap();
@@ -283,7 +288,7 @@ fn test_amplitude_encode_zero_norm_error() {
             *state_d.device_ptr_mut() as *mut std::ffi::c_void,
             3,
             state_len,
-            norm,
+            inv_norm,
             std::ptr::null_mut(),
         )
     };
@@ -311,6 +316,7 @@ fn test_amplitude_encode_negative_norm_error() {
 
     let input = vec![1.0, 2.0, 3.0];
     let norm = -5.0; // Invalid!
+    let inv_norm = if norm == 0.0 { 0.0 } else { 1.0 / norm };
     let state_len = 4;
 
     let input_d = device.htod_copy(input).unwrap();
@@ -322,7 +328,7 @@ fn test_amplitude_encode_negative_norm_error() {
             *state_d.device_ptr_mut() as *mut std::ffi::c_void,
             3,
             state_len,
-            norm,
+            inv_norm,
             std::ptr::null_mut(),
         )
     };
@@ -351,6 +357,7 @@ fn test_amplitude_encode_vectorized_load() {
     // Use exactly 16 elements to test vectorized loads (8 threads * 2 elements each)
     let input: Vec<f64> = (1..=16).map(|x| x as f64).collect();
     let norm: f64 = input.iter().map(|x| x * x).sum::<f64>().sqrt();
+    let inv_norm = 1.0 / norm;
     let state_len = 16;
 
     let input_d = device.htod_copy(input.clone()).unwrap();
@@ -362,7 +369,7 @@ fn test_amplitude_encode_vectorized_load() {
             *state_d.device_ptr_mut() as *mut std::ffi::c_void,
             input.len(),
             state_len,
-            norm,
+            inv_norm,
             std::ptr::null_mut(),
         )
     };
@@ -402,6 +409,7 @@ fn test_amplitude_encode_small_input_large_state() {
     // Only 2 input values, but 16-element state (padding with zeros)
     let input = vec![3.0, 4.0];
     let norm = 5.0;
+    let inv_norm = 1.0 / norm;
     let state_len = 16;
 
     let input_d = device.htod_copy(input.clone()).unwrap();
@@ -413,7 +421,7 @@ fn test_amplitude_encode_small_input_large_state() {
             *state_d.device_ptr_mut() as *mut std::ffi::c_void,
             input.len(),
             state_len,
-            norm,
+            inv_norm,
             std::ptr::null_mut(),
         )
     };
@@ -438,6 +446,103 @@ fn test_amplitude_encode_small_input_large_state() {
     println!("PASS: Small input with large state padding works correctly");
 }
 
+#[test]
+#[cfg(target_os = "linux")]
+fn test_l2_norm_single_kernel() {
+    println!("Testing single-vector GPU norm reduction...");
+
+    let device = match CudaDevice::new(0) {
+        Ok(d) => d,
+        Err(_) => {
+            println!("SKIP: No CUDA device available");
+            return;
+        }
+    };
+
+    let input = vec![3.0f64, 4.0f64];
+    let expected_inv = 1.0 / 5.0;
+    let input_d = device.htod_copy(input.clone()).unwrap();
+    let mut inv_norm_d = device.alloc_zeros::<f64>(1).unwrap();
+
+    let result = unsafe {
+        launch_l2_norm(
+            *input_d.device_ptr() as *const f64,
+            input.len(),
+            *inv_norm_d.device_ptr_mut() as *mut f64,
+            std::ptr::null_mut(),
+        )
+    };
+
+    assert_eq!(result, 0, "Norm kernel should succeed");
+
+    let host = device.dtoh_sync_copy(&inv_norm_d).unwrap();
+    assert!(
+        (host[0] - expected_inv).abs() < EPSILON,
+        "Expected inv norm {}, got {}",
+        expected_inv,
+        host[0]
+    );
+
+    println!("PASS: Single-vector norm reduction matches CPU");
+}
+
+#[test]
+#[cfg(target_os = "linux")]
+fn test_l2_norm_batch_kernel_stream() {
+    println!("Testing batched norm reduction on async stream...");
+
+    let device = match CudaDevice::new(0) {
+        Ok(d) => d,
+        Err(_) => {
+            println!("SKIP: No CUDA device available");
+            return;
+        }
+    };
+
+    // Two samples, four elements each
+    let sample_len = 4;
+    let num_samples = 2;
+    let input: Vec<f64> = vec![1.0, 2.0, 2.0, 1.0, 0.5, 0.5, 0.5, 0.5];
+    let expected: Vec<f64> = input
+        .chunks(sample_len)
+        .map(|chunk| {
+            let norm: f64 = chunk.iter().map(|v| v * v).sum::<f64>().sqrt();
+            1.0 / norm
+        })
+        .collect();
+
+    let stream = device.fork_default_stream().unwrap();
+    let input_d = device.htod_copy(input).unwrap();
+    let mut norms_d = device.alloc_zeros::<f64>(num_samples).unwrap();
+
+    let status = unsafe {
+        launch_l2_norm_batch(
+            *input_d.device_ptr() as *const f64,
+            num_samples,
+            sample_len,
+            *norms_d.device_ptr_mut() as *mut f64,
+            stream.stream as *mut std::ffi::c_void,
+        )
+    };
+
+    assert_eq!(status, 0, "Batch norm kernel should succeed");
+
+    device.wait_for(&stream).unwrap();
+    let norms_h = device.dtoh_sync_copy(&norms_d).unwrap();
+
+    for (i, (got, expect)) in norms_h.iter().zip(expected.iter()).enumerate() {
+        assert!(
+            (got - expect).abs() < EPSILON,
+            "Sample {} inv norm mismatch: expected {}, got {}",
+            i,
+            expect,
+            got
+        );
+    }
+
+    println!("PASS: Batched norm reduction on stream matches CPU");
+}
+
 #[test]
 #[cfg(not(target_os = "linux"))]
 fn test_amplitude_encode_dummy_non_linux() {
