This is an automated email from the ASF dual-hosted git repository.

guanmingchiu pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/mahout.git


The following commit(s) were added to refs/heads/main by this push:
     new 745143223 MAHOUT-900 [IQP] Implement FWT optimization for IQP encoding 
(#938)
745143223 is described below

commit 74514322363c8b3a0ab77b4f61328470f65ffea6
Author: Ryan Huang <[email protected]>
AuthorDate: Mon Feb 2 01:47:50 2026 +0800

    MAHOUT-900 [IQP] Implement FWT optimization for IQP encoding (#938)
    
    * MAHOUT-900 [IQP] Implement FWT optimization for IQP encoding and add 
corresponding tests and benchmarks
    
    - Added Fast Walsh-Hadamard Transform (FWT) optimization to IQP encoding, 
reducing complexity from O(4^n) to O(n * 2^n).
    - Introduced new kernels for phase computation and FWT stages in both 
single and batch encoding.
    - Updated kernel configuration with thresholds for shared memory usage.
    - Created tests to verify normalization and correctness of FWT-optimized 
IQP states.
    - Added benchmarking script to measure performance improvements from FWT 
optimization.
    
    * linter
    
    * add test coverage bits
    
    * remove single-purpose benchmark
    
    Signed-off-by: Hsien-Cheng Huang <[email protected]>
    
    ---------
    
    Signed-off-by: Hsien-Cheng Huang <[email protected]>
---
 qdp/qdp-core/tests/iqp_encoding.rs  | 239 +++++++++++++++++++++++++
 qdp/qdp-kernels/src/iqp.cu          | 347 +++++++++++++++++++++++++++++++++++-
 qdp/qdp-kernels/src/kernel_config.h |  12 ++
 qdp/qdp-python/uv.lock              |  83 +++++++++
 testing/qdp/test_bindings.py        | 193 ++++++++++++++++++++
 uv.lock                             |   4 +
 6 files changed, 870 insertions(+), 8 deletions(-)

diff --git a/qdp/qdp-core/tests/iqp_encoding.rs 
b/qdp/qdp-core/tests/iqp_encoding.rs
index 4fc48bcc5..7f976d0ad 100644
--- a/qdp/qdp-core/tests/iqp_encoding.rs
+++ b/qdp/qdp-core/tests/iqp_encoding.rs
@@ -585,6 +585,245 @@ fn test_iqp_data_length_calculations() {
     println!("PASS: Data length calculations are correct");
 }
 
+// 
=============================================================================
+// FWT Optimization Correctness Tests
+// 
=============================================================================
+
+#[test]
+#[cfg(target_os = "linux")]
+fn test_iqp_fwt_threshold_boundary() {
+    println!("Testing IQP FWT threshold boundary (n=4, where FWT kicks 
in)...");
+    // Smoke test at n = FWT_MIN_QUBITS (4), the smallest n routed to the FWT kernels.
+    let engine = match QdpEngine::new(0) {
+        Ok(e) => e,
+        Err(_) => {
+            println!("SKIP: No GPU available");
+            return;
+        }
+    };
+    // NOTE(review): only tensor rank and width are asserted; amplitudes are never read.
+    // FWT_MIN_QUBITS is 4 (kernel_config.h); the launchers use the naive kernel for n < 4.
+    let num_qubits = 4;
+    let data: Vec<f64> = (0..iqp_full_data_len(num_qubits))
+        .map(|i| (i as f64) * 0.1)
+        .collect();
+
+    let result = engine.encode(&data, num_qubits, "iqp");
+    let dlpack_ptr = result.expect("IQP encoding at FWT threshold should 
succeed");
+    assert!(!dlpack_ptr.is_null(), "DLPack pointer should not be null");
+    // SAFETY: assumes encode() returns a valid, owned DLManagedTensor; non-null is asserted.
+    unsafe {
+        let managed = &*dlpack_ptr;
+        let tensor = &managed.dl_tensor;
+        // Expect a 2-D tensor whose second dimension holds 2^num_qubits amplitudes.
+        assert_eq!(tensor.ndim, 2, "Tensor should be 2D");
+        let shape_slice = std::slice::from_raw_parts(tensor.shape, tensor.ndim 
as usize);
+        assert_eq!(
+            shape_slice[1],
+            1 << num_qubits,
+            "Should have 2^n amplitudes"
+        );
+
+        println!(
+            "PASS: IQP FWT threshold boundary test with shape [{}, {}]",
+            shape_slice[0], shape_slice[1]
+        );
+        // Release the tensor through its DLPack deleter, if one was provided.
+        if let Some(deleter) = managed.deleter {
+            deleter(dlpack_ptr);
+        }
+    }
+}
+
+#[test]
+#[cfg(target_os = "linux")]
+fn test_iqp_fwt_larger_qubit_counts() {
+    println!("Testing IQP FWT with larger qubit counts (n=5,6,7,8)...");
+    // Smoke test for n = 5..8: checks the amplitude count only, not amplitude values.
+    let engine = match QdpEngine::new(0) {
+        Ok(e) => e,
+        Err(_) => {
+            println!("SKIP: No GPU available");
+            return;
+        }
+    };
+    // NOTE(review): n = 5..8 are all <= FWT_SHARED_MEM_THRESHOLD (10); n > 10 is untested.
+    for num_qubits in [5, 6, 7, 8] {
+        let data: Vec<f64> = (0..iqp_full_data_len(num_qubits))
+            .map(|i| (i as f64) * 0.05)
+            .collect();
+
+        let result = engine.encode(&data, num_qubits, "iqp");
+        let dlpack_ptr = result
+            .unwrap_or_else(|_| panic!("IQP encoding for {} qubits should 
succeed", num_qubits));
+        assert!(!dlpack_ptr.is_null());
+        // SAFETY: assumes encode() returns a valid, owned DLManagedTensor; non-null checked.
+        unsafe {
+            let managed = &*dlpack_ptr;
+            let tensor = &managed.dl_tensor;
+
+            let shape_slice = std::slice::from_raw_parts(tensor.shape, 
tensor.ndim as usize);
+            assert_eq!(
+                shape_slice[1],
+                (1 << num_qubits) as i64,
+                "Should have 2^{} amplitudes",
+                num_qubits
+            );
+
+            println!(
+                "  {} qubits: shape [{}, {}] - PASS",
+                num_qubits, shape_slice[0], shape_slice[1]
+            );
+            // Free this tensor via its DLPack deleter before the next iteration.
+            if let Some(deleter) = managed.deleter {
+                deleter(dlpack_ptr);
+            }
+        }
+    }
+
+    println!("PASS: IQP FWT larger qubit count tests completed");
+}
+
+#[test]
+#[cfg(target_os = "linux")]
+fn test_iqp_z_fwt_correctness() {
+    println!("Testing IQP-Z FWT correctness for various qubit counts...");
+    // NOTE(review): despite the name, only the amplitude count is checked, not values.
+    let engine = match QdpEngine::new(0) {
+        Ok(e) => e,
+        Err(_) => {
+            println!("SKIP: No GPU available");
+            return;
+        }
+    };
+
+    // n = 3 is below FWT_MIN_QUBITS (4); n = 4..6 are at or above it.
+    for num_qubits in [3, 4, 5, 6] {
+        let data: Vec<f64> = (0..iqp_z_data_len(num_qubits))
+            .map(|i| (i as f64) * 0.15)
+            .collect();
+
+        let result = engine.encode(&data, num_qubits, "iqp-z");
+        let dlpack_ptr = result
+            .unwrap_or_else(|_| panic!("IQP-Z encoding for {} qubits should 
succeed", num_qubits));
+        assert!(!dlpack_ptr.is_null());
+        // SAFETY: assumes encode() returns a valid, owned DLManagedTensor; non-null checked.
+        unsafe {
+            let managed = &*dlpack_ptr;
+            let tensor = &managed.dl_tensor;
+
+            let shape_slice = std::slice::from_raw_parts(tensor.shape, 
tensor.ndim as usize);
+            assert_eq!(shape_slice[1], (1 << num_qubits) as i64);
+
+            println!(
+                "  IQP-Z {} qubits: shape [{}, {}] - PASS",
+                num_qubits, shape_slice[0], shape_slice[1]
+            );
+            // Free this tensor via its DLPack deleter before the next iteration.
+            if let Some(deleter) = managed.deleter {
+                deleter(dlpack_ptr);
+            }
+        }
+    }
+
+    println!("PASS: IQP-Z FWT correctness tests completed");
+}
+
+#[test]
+#[cfg(target_os = "linux")]
+fn test_iqp_fwt_batch_various_sizes() {
+    println!("Testing IQP FWT batch encoding with various qubit counts...");
+    // Smoke test for batch encoding: checks batch size and amplitude count only.
+    let engine = match QdpEngine::new(0) {
+        Ok(e) => e,
+        Err(_) => {
+            println!("SKIP: No GPU available");
+            return;
+        }
+    };
+
+    // n = 3 is below FWT_MIN_QUBITS (4); n = 4..6 are at or above it.
+    for num_qubits in [3, 4, 5, 6] {
+        let num_samples = 8;
+        let sample_size = iqp_full_data_len(num_qubits);
+        // One flat buffer holding num_samples * sample_size parameters.
+        let batch_data: Vec<f64> = (0..num_samples * sample_size)
+            .map(|i| (i as f64) * 0.02)
+            .collect();
+
+        let result = engine.encode_batch(&batch_data, num_samples, 
sample_size, num_qubits, "iqp");
+        let dlpack_ptr = result.unwrap_or_else(|_| {
+            panic!(
+                "IQP batch encoding for {} qubits should succeed",
+                num_qubits
+            )
+        });
+        assert!(!dlpack_ptr.is_null());
+        // SAFETY: assumes encode_batch() yields a valid, owned DLManagedTensor; non-null checked.
+        unsafe {
+            let managed = &*dlpack_ptr;
+            let tensor = &managed.dl_tensor;
+
+            let shape_slice = std::slice::from_raw_parts(tensor.shape, 
tensor.ndim as usize);
+            assert_eq!(shape_slice[0], num_samples as i64);
+            assert_eq!(shape_slice[1], (1 << num_qubits) as i64);
+
+            println!(
+                "  IQP batch {} qubits x {} samples: shape [{}, {}] - PASS",
+                num_qubits, num_samples, shape_slice[0], shape_slice[1]
+            );
+            // Free this tensor via its DLPack deleter before the next iteration.
+            if let Some(deleter) = managed.deleter {
+                deleter(dlpack_ptr);
+            }
+        }
+    }
+
+    println!("PASS: IQP FWT batch encoding tests completed");
+}
+
+#[test]
+#[cfg(target_os = "linux")]
+fn test_iqp_fwt_zero_parameters_identity() {
+    println!("Testing IQP FWT with zero parameters produces |0⟩ state...");
+    // NOTE(review): the name promises a |0⟩ check, but only the amplitude count is asserted.
+    let engine = match QdpEngine::new(0) {
+        Ok(e) => e,
+        Err(_) => {
+            println!("SKIP: No GPU available");
+            return;
+        }
+    };
+    // TODO(review): copy amplitudes to host and assert amp[0] == 1 and all others == 0.
+    // For FWT-optimized path (n >= 4), zero parameters should still give |0⟩
+    for num_qubits in [4, 5, 6] {
+        let data: Vec<f64> = vec![0.0; iqp_full_data_len(num_qubits)];
+
+        let result = engine.encode(&data, num_qubits, "iqp");
+        let dlpack_ptr = result.expect("IQP encoding with zero params should 
succeed");
+        assert!(!dlpack_ptr.is_null());
+        // SAFETY: assumes encode() returns a valid, owned DLManagedTensor; non-null checked.
+        unsafe {
+            let managed = &*dlpack_ptr;
+            let tensor = &managed.dl_tensor;
+
+            let shape_slice = std::slice::from_raw_parts(tensor.shape, 
tensor.ndim as usize);
+            assert_eq!(shape_slice[1], (1 << num_qubits) as i64);
+            // Amplitude values are not inspected here; see the TODO above.
+            println!(
+                "  IQP zero params {} qubits: verified shape - PASS",
+                num_qubits
+            );
+            // Free this tensor via its DLPack deleter before the next iteration.
+            if let Some(deleter) = managed.deleter {
+                deleter(dlpack_ptr);
+            }
+        }
+    }
+
+    println!("PASS: IQP FWT zero parameters test completed");
+}
+
 // 
=============================================================================
 // Encoder Factory Tests
 // 
=============================================================================
diff --git a/qdp/qdp-kernels/src/iqp.cu b/qdp/qdp-kernels/src/iqp.cu
index f0e63db50..9ec884a68 100644
--- a/qdp/qdp-kernels/src/iqp.cu
+++ b/qdp/qdp-kernels/src/iqp.cu
@@ -64,7 +64,11 @@ __device__ double compute_phase(
     return phase;
 }
 
-__global__ void iqp_encode_kernel(
+// ============================================================================
+// Naive O(4^n) Implementation (kept as fallback for small n and verification)
+// ============================================================================
+
+__global__ void iqp_encode_kernel_naive(
     const double* __restrict__ data,
     cuDoubleComplex* __restrict__ state,
     size_t state_len,
@@ -97,7 +101,134 @@ __global__ void iqp_encode_kernel(
     state[z] = make_cuDoubleComplex(real_sum * norm, imag_sum * norm);
 }
 
-__global__ void iqp_encode_batch_kernel(
+
+// ============================================================================
+// FWT O(n * 2^n) Implementation
+// ============================================================================
+
+// Step 1 (FWT path): write f[x] = exp(i*theta(x)) into state[x] for every x in [0, state_len).
+// One thread per x with no grid-stride loop, so the launch must cover all state_len entries.
+__global__ void iqp_phase_kernel(
+    const double* __restrict__ data,
+    cuDoubleComplex* __restrict__ state,
+    size_t state_len,
+    unsigned int num_qubits,
+    int enable_zz
+) {
+    size_t x = blockIdx.x * blockDim.x + threadIdx.x;
+    if (x >= state_len) return;
+    // theta(x) comes from compute_phase(), defined earlier in this file.
+    double phase = compute_phase(data, x, num_qubits, enable_zz);
+    // sincos(phase, &sin, &cos) fills sin first; store cos + i*sin = exp(i*phase).
+    double cos_phase, sin_phase;
+    sincos(phase, &sin_phase, &cos_phase);
+    state[x] = make_cuDoubleComplex(cos_phase, sin_phase);
+}
+
+// Step 2a: one Walsh-Hadamard stage in global memory (n > FWT_SHARED_MEM_THRESHOLD).
+// One thread per butterfly pair; the host launches this once per stage 0..num_qubits-1.
+// Butterfly: (a, b) -> (a + b, a - b), unnormalized (scaling is done by a later kernel).
+__global__ void fwt_butterfly_stage_kernel(
+    cuDoubleComplex* __restrict__ state,
+    size_t state_len,
+    unsigned int stage  // 0 to n-1
+) {
+    size_t idx = blockIdx.x * blockDim.x + threadIdx.x;
+    // Pairs are disjoint within a stage; same-stream launches order the stages.
+    // Each thread processes one butterfly pair
+    // For stage s, butterflies are separated by 2^s
+    size_t stride = 1ULL << stage;
+    size_t block_size = stride << 1;  // 2^(s+1)
+    size_t num_pairs = state_len >> 1;  // state_len / 2 total pairs
+    // Every stage has exactly state_len / 2 pairs; surplus threads exit here.
+    if (idx >= num_pairs) return;
+    // NOTE(review): stride == 2^stage, so idx >> stage and idx & (stride - 1) avoid 64-bit div/mod.
+    // Compute which butterfly pair this thread handles
+    size_t block_idx = idx / stride;
+    size_t pair_offset = idx % stride;
+    size_t i = block_idx * block_size + pair_offset;
+    size_t j = i + stride;
+
+    // Load values
+    cuDoubleComplex a = state[i];
+    cuDoubleComplex b = state[j];
+
+    // Butterfly: (a, b) -> (a + b, a - b)
+    state[i] = cuCadd(a, b);
+    state[j] = cuCsub(a, b);
+}
+
+// Step 2b: whole Walsh-Hadamard transform in shared memory (n <= FWT_SHARED_MEM_THRESHOLD).
+// All stages in one launch; needs state_len * sizeof(cuDoubleComplex) dynamic shared memory.
+__global__ void fwt_shared_memory_kernel(
+    cuDoubleComplex* __restrict__ state,
+    size_t state_len,
+    unsigned int num_qubits
+) {
+    extern __shared__ cuDoubleComplex shared_state[];
+    // Holds state_len amplitudes; its size comes from the launch's dynamic shared-memory argument.
+    size_t tid = threadIdx.x;
+    size_t bid = blockIdx.x;
+
+    // For shared memory FWT, we process the entire state in one block
+    // Block 0 handles the full transform
+    if (bid > 0) return;
+    // The return above is block-uniform, so every thread of block 0 reaches the __syncthreads().
+    // Load state into shared memory
+    for (size_t i = tid; i < state_len; i += blockDim.x) {
+        shared_state[i] = state[i];
+    }
+    __syncthreads();
+
+    // Perform all FWT stages in shared memory
+    for (unsigned int stage = 0; stage < num_qubits; ++stage) {
+        size_t stride = 1ULL << stage;
+        size_t block_size = stride << 1;
+        size_t num_pairs = state_len >> 1;
+        // Same pair indexing as fwt_butterfly_stage_kernel; pairs are disjoint within a stage.
+        // Each thread handles multiple pairs if needed
+        for (size_t pair_idx = tid; pair_idx < num_pairs; pair_idx += 
blockDim.x) {
+            size_t block_idx = pair_idx / stride;
+            size_t pair_offset = pair_idx % stride;
+            size_t i = block_idx * block_size + pair_offset;
+            size_t j = i + stride;
+
+            cuDoubleComplex a = shared_state[i];
+            cuDoubleComplex b = shared_state[j];
+
+            shared_state[i] = cuCadd(a, b);
+            shared_state[j] = cuCsub(a, b);
+        }
+        __syncthreads();
+    }
+    // The last __syncthreads() in the loop makes every stage's writes visible here.
+    // Write back to global memory
+    for (size_t i = tid; i < state_len; i += blockDim.x) {
+        state[i] = shared_state[i];
+    }
+}
+
+// Step 3: scale every amplitude by norm_factor (launch_iqp_encode passes 1/state_len = 1/2^n).
+__global__ void normalize_state_kernel(
+    cuDoubleComplex* __restrict__ state,
+    size_t state_len,
+    double norm_factor
+) {
+    size_t idx = blockIdx.x * blockDim.x + threadIdx.x;
+    if (idx >= state_len) return;
+    // Real and imaginary parts are scaled by the same real factor.
+    cuDoubleComplex val = state[idx];
+    state[idx] = make_cuDoubleComplex(
+        cuCreal(val) * norm_factor,
+        cuCimag(val) * norm_factor
+    );
+}
+
+// ============================================================================
+// Naive O(4^n) Batch Implementation (kept as fallback)
+// ============================================================================
+
+__global__ void iqp_encode_batch_kernel_naive(
     const double* __restrict__ data_batch,
     cuDoubleComplex* __restrict__ state_batch,
     size_t num_samples,
@@ -139,9 +270,108 @@ __global__ void iqp_encode_batch_kernel(
     }
 }
 
+
+// ============================================================================
+// FWT O(n * 2^n) Batch Implementation
+// ============================================================================
+
+// Batch step 1: state_batch[s * state_len + x] = exp(i*theta_s(x)) for every sample s and x.
+__global__ void iqp_phase_batch_kernel(
+    const double* __restrict__ data_batch,
+    cuDoubleComplex* __restrict__ state_batch,
+    size_t num_samples,
+    size_t state_len,
+    unsigned int num_qubits,
+    unsigned int data_len,
+    int enable_zz
+) {
+    const size_t total_elements = num_samples * state_len;
+    const size_t stride = gridDim.x * blockDim.x;
+    const size_t state_mask = state_len - 1;
+    // Grid-stride loop; >> num_qubits and & state_mask assume state_len == 2^num_qubits.
+    for (size_t global_idx = blockIdx.x * blockDim.x + threadIdx.x;
+         global_idx < total_elements;
+         global_idx += stride) {
+        const size_t sample_idx = global_idx >> num_qubits;
+        const size_t x = global_idx & state_mask;
+        const double* data = data_batch + sample_idx * data_len;
+        // Sample s reads its data_len parameters starting at data_batch[s * data_len].
+        double phase = compute_phase(data, x, num_qubits, enable_zz);
+        // sincos fills sin first, then cos; store cos + i*sin = exp(i*phase).
+        double cos_phase, sin_phase;
+        sincos(phase, &sin_phase, &cos_phase);
+        state_batch[global_idx] = make_cuDoubleComplex(cos_phase, sin_phase);
+    }
+}
+
+// Batch step 2: one Walsh-Hadamard stage applied to every sample (global memory).
+// Grid-stride over num_samples * state_len / 2 pairs; num_qubits is currently unused.
+__global__ void fwt_butterfly_batch_kernel(
+    cuDoubleComplex* __restrict__ state_batch,
+    size_t num_samples,
+    size_t state_len,
+    unsigned int num_qubits,
+    unsigned int stage
+) {
+    const size_t pairs_per_sample = state_len >> 1;
+    const size_t total_pairs = num_samples * pairs_per_sample;
+    const size_t grid_stride = gridDim.x * blockDim.x;
+
+    // For stage s, butterflies are separated by 2^s
+    const size_t stride = 1ULL << stage;
+    const size_t block_size = stride << 1;
+
+    for (size_t global_pair_idx = blockIdx.x * blockDim.x + threadIdx.x;
+         global_pair_idx < total_pairs;
+         global_pair_idx += grid_stride) {
+        // NOTE(review): divisors are powers of two, so shift/mask could replace 64-bit div/mod.
+        // Determine which sample and which pair within that sample
+        const size_t sample_idx = global_pair_idx / pairs_per_sample;
+        const size_t pair_idx = global_pair_idx % pairs_per_sample;
+
+        // Compute indices within this sample's state
+        const size_t block_idx = pair_idx / stride;
+        const size_t pair_offset = pair_idx % stride;
+        const size_t local_i = block_idx * block_size + pair_offset;
+        const size_t local_j = local_i + stride;
+
+        // Global indices
+        const size_t base = sample_idx * state_len;
+        const size_t i = base + local_i;
+        const size_t j = base + local_j;
+
+        // Load values
+        cuDoubleComplex a = state_batch[i];
+        cuDoubleComplex b = state_batch[j];
+
+        // Butterfly: (a, b) -> (a + b, a - b)
+        state_batch[i] = cuCadd(a, b);
+        state_batch[j] = cuCsub(a, b);
+    }
+}
+
+// Batch step 3: multiply every amplitude in the batch by norm_factor.
+__global__ void normalize_batch_kernel(
+    cuDoubleComplex* __restrict__ state_batch,
+    size_t total_elements,
+    double norm_factor
+) {
+    const size_t stride = gridDim.x * blockDim.x;
+    // Grid-stride loop, so a grid capped at MAX_GRID_BLOCKS still covers total_elements.
+    for (size_t idx = blockIdx.x * blockDim.x + threadIdx.x;
+         idx < total_elements;
+         idx += stride) {
+        cuDoubleComplex val = state_batch[idx];
+        state_batch[idx] = make_cuDoubleComplex(
+            cuCreal(val) * norm_factor,
+            cuCimag(val) * norm_factor
+        );
+    }
+}
+
 extern "C" {
 
-/// Launch IQP encoding kernel
+/// Launch IQP encoding kernel using FWT optimization
 ///
 /// # Arguments
 /// * data_d - Device pointer to encoding parameters
@@ -153,6 +383,13 @@ extern "C" {
 ///
 /// # Returns
 /// CUDA error code (0 = cudaSuccess)
+///
+/// # Algorithm
+/// For num_qubits >= FWT_MIN_QUBITS, uses Fast Walsh-Hadamard Transform:
+///   1. Phase computation: f[x] = exp(i*theta(x)) - O(2^n)
+///   2. FWT transform: WHT of phase array - O(n * 2^n)
+///   3. Normalization: divide by 2^n - O(2^n)
+/// Total: O(n * 2^n) vs naive O(4^n)
 int launch_iqp_encode(
     const double* data_d,
     void* state_d,
@@ -166,11 +403,26 @@ int launch_iqp_encode(
     }
 
     cuDoubleComplex* state_complex_d = static_cast<cuDoubleComplex*>(state_d);
-
     const int blockSize = DEFAULT_BLOCK_SIZE;
+
+    // Use naive kernel for small n (FWT overhead not worth it)
+    if (num_qubits < FWT_MIN_QUBITS) {
+        const int gridSize = (state_len + blockSize - 1) / blockSize;
+        iqp_encode_kernel_naive<<<gridSize, blockSize, 0, stream>>>(
+            data_d,
+            state_complex_d,
+            state_len,
+            num_qubits,
+            enable_zz
+        );
+        return (int)cudaGetLastError();
+    }
+
+    // FWT-based implementation for larger n
     const int gridSize = (state_len + blockSize - 1) / blockSize;
 
-    iqp_encode_kernel<<<gridSize, blockSize, 0, stream>>>(
+    // Step 1: Compute phase array f[x] = exp(i*theta(x))
+    iqp_phase_kernel<<<gridSize, blockSize, 0, stream>>>(
         data_d,
         state_complex_d,
         state_len,
@@ -178,10 +430,41 @@ int launch_iqp_encode(
         enable_zz
     );
 
+    // Step 2: Apply FWT
+    if (num_qubits <= FWT_SHARED_MEM_THRESHOLD) {
+        // Shared memory FWT - all stages in one kernel
+        size_t shared_mem_size = state_len * sizeof(cuDoubleComplex);
+        fwt_shared_memory_kernel<<<1, blockSize, shared_mem_size, stream>>>(
+            state_complex_d,
+            state_len,
+            num_qubits
+        );
+    } else {
+        // Global memory FWT - one kernel launch per stage
+        const size_t num_pairs = state_len >> 1;
+        const int fwt_grid_size = (num_pairs + blockSize - 1) / blockSize;
+
+        for (unsigned int stage = 0; stage < num_qubits; ++stage) {
+            fwt_butterfly_stage_kernel<<<fwt_grid_size, blockSize, 0, 
stream>>>(
+                state_complex_d,
+                state_len,
+                stage
+            );
+        }
+    }
+
+    // Step 3: Normalize by 1/2^n
+    double norm_factor = 1.0 / (double)state_len;
+    normalize_state_kernel<<<gridSize, blockSize, 0, stream>>>(
+        state_complex_d,
+        state_len,
+        norm_factor
+    );
+
     return (int)cudaGetLastError();
 }
 
-/// Launch batch IQP encoding kernel
+/// Launch batch IQP encoding kernel using FWT optimization
 ///
 /// # Arguments
 /// * data_batch_d - Device pointer to batch parameters (num_samples * 
data_len)
@@ -195,6 +478,13 @@ int launch_iqp_encode(
 ///
 /// # Returns
 /// CUDA error code (0 = cudaSuccess)
+///
+/// # Algorithm
+/// For num_qubits >= FWT_MIN_QUBITS, uses Fast Walsh-Hadamard Transform:
+///   1. Phase computation for all samples - O(batch * 2^n)
+///   2. FWT transform for all samples - O(batch * n * 2^n)
+///   3. Normalization - O(batch * 2^n)
+/// Total: O(batch * n * 2^n) vs naive O(batch * 4^n)
 int launch_iqp_encode_batch(
     const double* data_batch_d,
     void* state_batch_d,
@@ -210,13 +500,29 @@ int launch_iqp_encode_batch(
     }
 
     cuDoubleComplex* state_complex_d = 
static_cast<cuDoubleComplex*>(state_batch_d);
-
     const int blockSize = DEFAULT_BLOCK_SIZE;
     const size_t total_elements = num_samples * state_len;
     const size_t blocks_needed = (total_elements + blockSize - 1) / blockSize;
     const size_t gridSize = (blocks_needed < MAX_GRID_BLOCKS) ? blocks_needed 
: MAX_GRID_BLOCKS;
 
-    iqp_encode_batch_kernel<<<gridSize, blockSize, 0, stream>>>(
+    // Use naive kernel for small n (FWT overhead not worth it)
+    if (num_qubits < FWT_MIN_QUBITS) {
+        iqp_encode_batch_kernel_naive<<<gridSize, blockSize, 0, stream>>>(
+            data_batch_d,
+            state_complex_d,
+            num_samples,
+            state_len,
+            num_qubits,
+            data_len,
+            enable_zz
+        );
+        return (int)cudaGetLastError();
+    }
+
+    // FWT-based implementation for larger n
+
+    // Step 1: Compute phase array f[x] = exp(i*theta(x)) for all samples
+    iqp_phase_batch_kernel<<<gridSize, blockSize, 0, stream>>>(
         data_batch_d,
         state_complex_d,
         num_samples,
@@ -226,6 +532,31 @@ int launch_iqp_encode_batch(
         enable_zz
     );
 
+    // Step 2: Apply FWT to all samples (global memory version for batch)
+    // For batch processing, we always use global memory FWT
+    // (shared memory would require processing samples one at a time)
+    const size_t total_pairs = num_samples * (state_len >> 1);
+    const size_t fwt_blocks_needed = (total_pairs + blockSize - 1) / blockSize;
+    const size_t fwt_grid_size = (fwt_blocks_needed < MAX_GRID_BLOCKS) ? 
fwt_blocks_needed : MAX_GRID_BLOCKS;
+
+    for (unsigned int stage = 0; stage < num_qubits; ++stage) {
+        fwt_butterfly_batch_kernel<<<fwt_grid_size, blockSize, 0, stream>>>(
+            state_complex_d,
+            num_samples,
+            state_len,
+            num_qubits,
+            stage
+        );
+    }
+
+    // Step 3: Normalize by 1/2^n
+    double norm_factor = 1.0 / (double)state_len;
+    normalize_batch_kernel<<<gridSize, blockSize, 0, stream>>>(
+        state_complex_d,
+        total_elements,
+        norm_factor
+    );
+
     return (int)cudaGetLastError();
 }
 
diff --git a/qdp/qdp-kernels/src/kernel_config.h 
b/qdp/qdp-kernels/src/kernel_config.h
index b00ef169f..4ce4526cb 100644
--- a/qdp/qdp-kernels/src/kernel_config.h
+++ b/qdp/qdp-kernels/src/kernel_config.h
@@ -57,6 +57,18 @@
 // This limit ensures state vectors fit within practical GPU memory constraints
 #define MAX_QUBITS 30
 
+// ============================================================================
+// FWT (Fast Walsh-Hadamard Transform) Configuration
+// ============================================================================
+// Shared-memory FWT threshold; used by launch_iqp_encode only (batch is always global).
+// n <= threshold: one fwt_shared_memory_kernel launch using 2^n * 16 B of shared memory.
+// n >  threshold: one fwt_butterfly_stage_kernel launch per stage (n launches).
+// 10 qubits -> 16 KB; values above 11 exceed the 48 KB default dynamic shared-memory limit.
+#define FWT_SHARED_MEM_THRESHOLD 10
+
+// Minimum qubits for the FWT path; below it both launchers use the naive O(4^n) kernels.
+#define FWT_MIN_QUBITS 4
+
 // ============================================================================
 // Convenience Macros
 // ============================================================================
diff --git a/qdp/qdp-python/uv.lock b/qdp/qdp-python/uv.lock
index 658438764..b75377bd9 100644
--- a/qdp/qdp-python/uv.lock
+++ b/qdp/qdp-python/uv.lock
@@ -572,6 +572,18 @@ wheels = [
     { url = "https://files.pythonhosted.org/packages/be/95/03c8215f675349ff719cb44cd837c2468fdc0c05f55f523f3cad86bbdcc6/duet-0.2.9-py3-none-any.whl", hash = "sha256:a16088b68b0faee8aee12cdf4d0a8af060ed958badb44f3e32f123f13f64119a", size = 29560, upload-time = "2023-07-26T06:38:58.931Z" },
 ]
 
+[[package]]
+name = "exceptiongroup"
+version = "1.3.1"
+source = { registry = "https://pypi.org/simple" }
+dependencies = [
+    { name = "typing-extensions", marker = "python_full_version < '3.11'" },
+]
+sdist = { url = "https://files.pythonhosted.org/packages/50/79/66800aadf48771f6b62f7eb014e352e5d06856655206165d775e675a02c9/exceptiongroup-1.3.1.tar.gz", hash = "sha256:8b412432c6055b0b7d14c310000ae93352ed6754f70fa8f7c34141f91c4e3219", size = 30371, upload-time = "2025-11-21T23:01:54.787Z" }
+wheels = [
+    { url = "https://files.pythonhosted.org/packages/8a/0e/97c33bf5009bdbac74fd2beace167cab3f978feb69cc36f1ef79360d6c4e/exceptiongroup-1.3.1-py3-none-any.whl", hash = "sha256:a7a39a3bd276781e98394987d3a5701d0c4edffb633bb7a5144577f82c773598", size = 16740, upload-time = "2025-11-21T23:01:53.443Z" },
+]
+
 [[package]]
 name = "filelock"
 version = "3.20.0"
@@ -797,6 +809,15 @@ wheels = [
     { url = "https://files.pythonhosted.org/packages/0e/61/66938bbb5fc52dbdf84594873d5b51fb1f7c7794e9c0f5bd885f30bc507b/idna-3.11-py3-none-any.whl", hash = "sha256:771a87f49d9defaf64091e6e6fe9c18d4833f140bd19464795bc32d966ca37ea", size = 71008, upload-time = "2025-10-12T14:55:18.883Z" },
 ]
 
+[[package]]
+name = "iniconfig"
+version = "2.3.0"
+source = { registry = "https://pypi.org/simple" }
+sdist = { url = "https://files.pythonhosted.org/packages/72/34/14ca021ce8e5dfedc35312d08ba8bf51fdd999c576889fc2c24cb97f4f10/iniconfig-2.3.0.tar.gz", hash = "sha256:c76315c77db068650d49c5b56314774a7804df16fee4402c1f19d6d15d8c4730", size = 20503, upload-time = "2025-10-18T21:55:43.219Z" }
+wheels = [
+    { url = "https://files.pythonhosted.org/packages/cb/b1/3846dd7f199d53cb17f49cba7e651e9ce294d8497c8c150530ed11865bb8/iniconfig-2.3.0-py3-none-any.whl", hash = "sha256:f631c04d2c48c52b84d0d0549c99ff3859c98df65b3101406327ecc7d53fbf12", size = 7484, upload-time = "2025-10-18T21:55:41.639Z" },
+]
+
 [[package]]
 name = "jinja2"
 version = "3.1.6"
@@ -1635,6 +1656,15 @@ wheels = [
     { url = "https://files.pythonhosted.org/packages/95/7e/f896623c3c635a90537ac093c6a618ebe1a90d87206e42309cb5d98a1b9e/pillow-12.0.0-pp311-pypy311_pp73-win_amd64.whl", hash = "sha256:b290fd8aa38422444d4b50d579de197557f182ef1068b75f5aa8558638b8d0a5", size = 6997850, upload-time = "2025-10-15T18:24:11.495Z" },
 ]
 
+[[package]]
+name = "pluggy"
+version = "1.6.0"
+source = { registry = "https://pypi.org/simple" }
+sdist = { url = "https://files.pythonhosted.org/packages/f9/e2/3e91f31a7d2b083fe6ef3fa267035b518369d9511ffab804f839851d2779/pluggy-1.6.0.tar.gz", hash = "sha256:7dcc130b76258d33b90f61b658791dede3486c3e6bfb003ee5c9bfb396dd22f3", size = 69412, upload-time = "2025-05-15T12:30:07.975Z" }
+wheels = [
+    { url = "https://files.pythonhosted.org/packages/54/20/4d324d65cc6d9205fabedc306948156824eb9f0ee1633355a8f7ec5c66bf/pluggy-1.6.0-py3-none-any.whl", hash = "sha256:e920276dd6813095e9377c0bc5566d94c932c33b27a3e3945d8389c374dd4746", size = 20538, upload-time = "2025-05-15T12:30:06.134Z" },
+]
+
 [[package]]
 name = "proto-plus"
 version = "1.27.0"
@@ -1836,6 +1866,24 @@ wheels = [
     { url = "https://files.pythonhosted.org/packages/8b/40/2614036cdd416452f5bf98ec037f38a1afb17f327cb8e6b652d4729e0af8/pyparsing-3.3.1-py3-none-any.whl", hash = "sha256:023b5e7e5520ad96642e2c6db4cb683d3970bd640cdf7115049a6e9c3682df82", size = 121793, upload-time = "2025-12-23T03:14:02.103Z" },
 ]
 
+[[package]]
+name = "pytest"
+version = "9.0.2"
+source = { registry = "https://pypi.org/simple" }
+dependencies = [
+    { name = "colorama", marker = "sys_platform == 'win32'" },
+    { name = "exceptiongroup", marker = "python_full_version < '3.11'" },
+    { name = "iniconfig" },
+    { name = "packaging" },
+    { name = "pluggy" },
+    { name = "pygments" },
+    { name = "tomli", marker = "python_full_version < '3.11'" },
+]
+sdist = { url = "https://files.pythonhosted.org/packages/d1/db/7ef3487e0fb0049ddb5ce41d3a49c235bf9ad299b6a25d5780a89f19230f/pytest-9.0.2.tar.gz", hash = "sha256:75186651a92bd89611d1d9fc20f0b4345fd827c41ccd5c299a868a05d70edf11", size = 1568901, upload-time = "2025-12-06T21:30:51.014Z" }
+wheels = [
+    { url = "https://files.pythonhosted.org/packages/3b/ab/b3226f0bd7cdcf710fbede2b3548584366da3b19b5021e74f5bde2a8fa3f/pytest-9.0.2-py3-none-any.whl", hash = "sha256:711ffd45bf766d5264d487b917733b453d917afd2b0ad65223959f59089f875b", size = 374801, upload-time = "2025-12-06T21:30:49.154Z" },
+]
+
 [[package]]
 name = "python-dateutil"
 version = "2.9.0.post0"
@@ -1980,6 +2028,10 @@ benchmark = [
     { name = "torch" },
     { name = "tqdm" },
 ]
+dev = [
+    { name = "pytest" },
+    { name = "torch" },
+]
 
 [package.metadata]
 requires-dist = [{ name = "qumat", editable = "../../" }]
@@ -1998,6 +2050,10 @@ benchmark = [
     { name = "torch", specifier = ">=2.2,<=2.9.0" },
     { name = "tqdm" },
 ]
+dev = [
+    { name = "pytest" },
+    { name = "torch", specifier = ">=2.2,<=2.9.0" },
+]
 
 [[package]]
 name = "requests"
@@ -2369,6 +2425,33 @@ wheels = [
     { url = "https://files.pythonhosted.org/packages/32/d5/f9a850d79b0851d1d4ef6456097579a9005b31fea68726a4ae5f2d82ddd9/threadpoolctl-3.6.0-py3-none-any.whl", hash = "sha256:43a0b8fd5a2928500110039e43a5eed8480b918967083ea48dc3ab9f13c4a7fb", size = 18638, upload-time = "2025-03-13T13:49:21.846Z" },
 ]
 
+[[package]]
+name = "tomli"
+version = "2.4.0"
+source = { registry = "https://pypi.org/simple" }
+sdist = { url = "https://files.pythonhosted.org/packages/82/30/31573e9457673ab10aa432461bee537ce6cef177667deca369efb79df071/tomli-2.4.0.tar.gz", hash = "sha256:aa89c3f6c277dd275d8e243ad24f3b5e701491a860d5121f2cdd399fbb31fc9c", size = 17477, upload-time = "2026-01-11T11:22:38.165Z" }
+wheels = [
+    { url = "https://files.pythonhosted.org/packages/3c/d9/3dc2289e1f3b32eb19b9785b6a006b28ee99acb37d1d47f78d4c10e28bf8/tomli-2.4.0-cp311-cp311-macosx_10_9_x86_64.whl", hash = "sha256:b5ef256a3fd497d4973c11bf142e9ed78b150d36f5773f1ca6088c230ffc5867", size = 153663, upload-time = "2026-01-11T11:21:45.27Z" },
+    { url = "https://files.pythonhosted.org/packages/51/32/ef9f6845e6b9ca392cd3f64f9ec185cc6f09f0a2df3db08cbe8809d1d435/tomli-2.4.0-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:5572e41282d5268eb09a697c89a7bee84fae66511f87533a6f88bd2f7b652da9", size = 148469, upload-time = "2026-01-11T11:21:46.873Z" },
+    { url = "https://files.pythonhosted.org/packages/d6/c2/506e44cce89a8b1b1e047d64bd495c22c9f71f21e05f380f1a950dd9c217/tomli-2.4.0-cp311-cp311-manylinux2014_aarch64.manylinux_2_17_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:551e321c6ba03b55676970b47cb1b73f14a0a4dce6a3e1a9458fd6d921d72e95", size = 236039, upload-time = "2026-01-11T11:21:48.503Z" },
+    { url = "https://files.pythonhosted.org/packages/b3/40/e1b65986dbc861b7e986e8ec394598187fa8aee85b1650b01dd925ca0be8/tomli-2.4.0-cp311-cp311-manylinux2014_x86_64.manylinux_2_17_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:5e3f639a7a8f10069d0e15408c0b96a2a828cfdec6fca05296ebcdcc28ca7c76", size = 243007, upload-time = "2026-01-11T11:21:49.456Z" },
+    { url = "https://files.pythonhosted.org/packages/9c/6f/6e39ce66b58a5b7ae572a0f4352ff40c71e8573633deda43f6a379d56b3e/tomli-2.4.0-cp311-cp311-musllinux_1_2_aarch64.whl", hash = "sha256:1b168f2731796b045128c45982d3a4874057626da0e2ef1fdd722848b741361d", size = 240875, upload-time = "2026-01-11T11:21:50.755Z" },
+    { url = "https://files.pythonhosted.org/packages/aa/ad/cb089cb190487caa80204d503c7fd0f4d443f90b95cf4ef5cf5aa0f439b0/tomli-2.4.0-cp311-cp311-musllinux_1_2_x86_64.whl", hash = "sha256:133e93646ec4300d651839d382d63edff11d8978be23da4cc106f5a18b7d0576", size = 246271, upload-time = "2026-01-11T11:21:51.81Z" },
+    { url = "https://files.pythonhosted.org/packages/0b/63/69125220e47fd7a3a27fd0de0c6398c89432fec41bc739823bcc66506af6/tomli-2.4.0-cp311-cp311-win32.whl", hash = "sha256:b6c78bdf37764092d369722d9946cb65b8767bfa4110f902a1b2542d8d173c8a", size = 96770, upload-time = "2026-01-11T11:21:52.647Z" },
+    { url = "https://files.pythonhosted.org/packages/1e/0d/a22bb6c83f83386b0008425a6cd1fa1c14b5f3dd4bad05e98cf3dbbf4a64/tomli-2.4.0-cp311-cp311-win_amd64.whl", hash = "sha256:d3d1654e11d724760cdb37a3d7691f0be9db5fbdaef59c9f532aabf87006dbaa", size = 107626, upload-time = "2026-01-11T11:21:53.459Z" },
+    { url = "https://files.pythonhosted.org/packages/2f/6d/77be674a3485e75cacbf2ddba2b146911477bd887dda9d8c9dfb2f15e871/tomli-2.4.0-cp311-cp311-win_arm64.whl", hash = "sha256:cae9c19ed12d4e8f3ebf46d1a75090e4c0dc16271c5bce1c833ac168f08fb614", size = 94842, upload-time = "2026-01-11T11:21:54.831Z" },
+    { url = "https://files.pythonhosted.org/packages/3c/43/7389a1869f2f26dba52404e1ef13b4784b6b37dac93bac53457e3ff24ca3/tomli-2.4.0-cp312-cp312-macosx_10_13_x86_64.whl", hash = "sha256:920b1de295e72887bafa3ad9f7a792f811847d57ea6b1215154030cf131f16b1", size = 154894, upload-time = "2026-01-11T11:21:56.07Z" },
+    { url = "https://files.pythonhosted.org/packages/e9/05/2f9bf110b5294132b2edf13fe6ca6ae456204f3d749f623307cbb7a946f2/tomli-2.4.0-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:7d6d9a4aee98fac3eab4952ad1d73aee87359452d1c086b5ceb43ed02ddb16b8", size = 149053, upload-time = "2026-01-11T11:21:57.467Z" },
+    { url = "https://files.pythonhosted.org/packages/e8/41/1eda3ca1abc6f6154a8db4d714a4d35c4ad90adc0bcf700657291593fbf3/tomli-2.4.0-cp312-cp312-manylinux2014_aarch64.manylinux_2_17_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:36b9d05b51e65b254ea6c2585b59d2c4cb91c8a3d91d0ed0f17591a29aaea54a", size = 243481, upload-time = "2026-01-11T11:21:58.661Z" },
+    { url = "https://files.pythonhosted.org/packages/d2/6d/02ff5ab6c8868b41e7d4b987ce2b5f6a51d3335a70aa144edd999e055a01/tomli-2.4.0-cp312-cp312-manylinux2014_x86_64.manylinux_2_17_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:1c8a885b370751837c029ef9bc014f27d80840e48bac415f3412e6593bbc18c1", size = 251720, upload-time = "2026-01-11T11:22:00.178Z" },
+    { url = "https://files.pythonhosted.org/packages/7b/57/0405c59a909c45d5b6f146107c6d997825aa87568b042042f7a9c0afed34/tomli-2.4.0-cp312-cp312-musllinux_1_2_aarch64.whl", hash = "sha256:8768715ffc41f0008abe25d808c20c3d990f42b6e2e58305d5da280ae7d1fa3b", size = 247014, upload-time = "2026-01-11T11:22:01.238Z" },
+    { url = "https://files.pythonhosted.org/packages/2c/0e/2e37568edd944b4165735687cbaf2fe3648129e440c26d02223672ee0630/tomli-2.4.0-cp312-cp312-musllinux_1_2_x86_64.whl", hash = "sha256:7b438885858efd5be02a9a133caf5812b8776ee0c969fea02c45e8e3f296ba51", size = 251820, upload-time = "2026-01-11T11:22:02.727Z" },
+    { url = "https://files.pythonhosted.org/packages/5a/1c/ee3b707fdac82aeeb92d1a113f803cf6d0f37bdca0849cb489553e1f417a/tomli-2.4.0-cp312-cp312-win32.whl", hash = "sha256:0408e3de5ec77cc7f81960c362543cbbd91ef883e3138e81b729fc3eea5b9729", size = 97712, upload-time = "2026-01-11T11:22:03.777Z" },
+    { url = "https://files.pythonhosted.org/packages/69/13/c07a9177d0b3bab7913299b9278845fc6eaaca14a02667c6be0b0a2270c8/tomli-2.4.0-cp312-cp312-win_amd64.whl", hash = "sha256:685306e2cc7da35be4ee914fd34ab801a6acacb061b6a7abca922aaf9ad368da", size = 108296, upload-time = "2026-01-11T11:22:04.86Z" },
+    { url = "https://files.pythonhosted.org/packages/18/27/e267a60bbeeee343bcc279bb9e8fbed0cbe224bc7b2a3dc2975f22809a09/tomli-2.4.0-cp312-cp312-win_arm64.whl", hash = "sha256:5aa48d7c2356055feef06a43611fc401a07337d5b006be13a30f6c58f869e3c3", size = 94553, upload-time = "2026-01-11T11:22:05.854Z" },
+    { url = "https://files.pythonhosted.org/packages/23/d1/136eb2cb77520a31e1f64cbae9d33ec6df0d78bdf4160398e86eec8a8754/tomli-2.4.0-py3-none-any.whl", hash = "sha256:1f776e7d669ebceb01dee46484485f43a4048746235e683bcdffacdf1fb4785a", size = 14477, upload-time = "2026-01-11T11:22:37.446Z" },
+]
+
 [[package]]
 name = "tomlkit"
 version = "0.13.3"
diff --git a/testing/qdp/test_bindings.py b/testing/qdp/test_bindings.py
index d213d55cd..0e692a366 100644
--- a/testing/qdp/test_bindings.py
+++ b/testing/qdp/test_bindings.py
@@ -1099,3 +1099,196 @@ def test_iqp_encode_errors():
     # Non-finite parameter (negative infinity)
     with pytest.raises(RuntimeError, match="must be finite"):
         engine.encode([float("-inf"), 0.0], 2, "iqp-z")
+
+
+# ==================== IQP FWT Optimization Tests ====================
+
+
+@pytest.mark.gpu
+def test_iqp_fwt_normalization():
+    """Test that FWT-optimized IQP produces normalized states (requires GPU)."""
+    pytest.importorskip("torch")
+    import torch
+    from _qdp import QdpEngine
+
+    if not torch.cuda.is_available():
+        pytest.skip("GPU required for QdpEngine")
+
+    engine = QdpEngine(0)
+
+    # Test across FWT threshold (FWT_MIN_QUBITS = 4)
+    for num_qubits in [3, 4, 5, 6, 7, 8]:
+        # Full IQP (n + n*(n-1)/2 parameters)
+        data_len = num_qubits + num_qubits * (num_qubits - 1) // 2
+        data = [0.1 * i for i in range(data_len)]
+
+        qtensor = engine.encode(data, num_qubits, "iqp")
+        torch_tensor = torch.from_dlpack(qtensor)
+
+        # Verify normalization (sum of |amplitude|^2 = 1)
+        norm = torch.sum(torch.abs(torch_tensor) ** 2)
+        assert torch.isclose(norm, torch.tensor(1.0, device="cuda:0"), atol=1e-6), (
+            f"IQP {num_qubits} qubits not normalized: got {norm.item()}"
+        )
+
+
+@pytest.mark.gpu
+def test_iqp_z_fwt_normalization():
+    """Test that FWT-optimized IQP-Z produces normalized states (requires GPU)."""
+    pytest.importorskip("torch")
+    import torch
+    from _qdp import QdpEngine
+
+    if not torch.cuda.is_available():
+        pytest.skip("GPU required for QdpEngine")
+
+    engine = QdpEngine(0)
+
+    # Test across FWT threshold
+    for num_qubits in [3, 4, 5, 6, 7, 8]:
+        data = [0.2 * i for i in range(num_qubits)]
+
+        qtensor = engine.encode(data, num_qubits, "iqp-z")
+        torch_tensor = torch.from_dlpack(qtensor)
+
+        norm = torch.sum(torch.abs(torch_tensor) ** 2)
+        assert torch.isclose(norm, torch.tensor(1.0, device="cuda:0"), atol=1e-6), (
+            f"IQP-Z {num_qubits} qubits not normalized: got {norm.item()}"
+        )
+
+
+@pytest.mark.gpu
+def test_iqp_fwt_zero_params_gives_zero_state():
+    """Test that zero parameters produce |0...0⟩ state (requires GPU).
+
+    With zero parameters, the IQP circuit is H^n * I * H^n = I,
+    so |0⟩^n maps to |0⟩^n with amplitude 1 at index 0.
+    """
+    pytest.importorskip("torch")
+    import torch
+    from _qdp import QdpEngine
+
+    if not torch.cuda.is_available():
+        pytest.skip("GPU required for QdpEngine")
+
+    engine = QdpEngine(0)
+
+    # Test FWT-optimized path (n >= 4)
+    for num_qubits in [4, 5, 6]:
+        data_len = num_qubits + num_qubits * (num_qubits - 1) // 2
+        data = [0.0] * data_len
+
+        qtensor = engine.encode(data, num_qubits, "iqp")
+        torch_tensor = torch.from_dlpack(qtensor)
+
+        # Should get |0...0⟩: amplitude 1 at index 0, 0 elsewhere
+        state_len = 1 << num_qubits
+        expected = torch.zeros(
+            (1, state_len), dtype=torch_tensor.dtype, device="cuda:0"
+        )
+        expected[0, 0] = 1.0 + 0j
+
+        assert torch.allclose(torch_tensor, expected, atol=1e-6), (
+            f"IQP {num_qubits} qubits with zero params should give |0⟩ state"
+        )
+
+
+@pytest.mark.gpu
+def test_iqp_fwt_batch_normalization():
+    """Test that FWT-optimized batch IQP produces normalized states (requires GPU)."""
+    pytest.importorskip("torch")
+    import torch
+    from _qdp import QdpEngine
+
+    if not torch.cuda.is_available():
+        pytest.skip("GPU required for QdpEngine")
+
+    engine = QdpEngine(0)
+
+    # Test batch encoding across FWT threshold
+    for num_qubits in [4, 5, 6]:
+        data_len = num_qubits + num_qubits * (num_qubits - 1) // 2
+        batch_size = 8
+
+        data = torch.tensor(
+            [
+                [0.1 * (i + j * data_len) for i in range(data_len)]
+                for j in range(batch_size)
+            ],
+            dtype=torch.float64,
+        )
+
+        qtensor = engine.encode(data, num_qubits, "iqp")
+        torch_tensor = torch.from_dlpack(qtensor)
+
+        assert torch_tensor.shape == (batch_size, 1 << num_qubits)
+
+        # Check each sample is normalized
+        for i in range(batch_size):
+            norm = torch.sum(torch.abs(torch_tensor[i]) ** 2)
+            assert torch.isclose(norm, torch.tensor(1.0, device="cuda:0"), atol=1e-6), (
+                f"IQP batch sample {i} not normalized: got {norm.item()}"
+            )
+
+
+@pytest.mark.gpu
+def test_iqp_fwt_deterministic():
+    """Test that FWT-optimized IQP is deterministic (requires GPU)."""
+    pytest.importorskip("torch")
+    import torch
+    from _qdp import QdpEngine
+
+    if not torch.cuda.is_available():
+        pytest.skip("GPU required for QdpEngine")
+
+    engine = QdpEngine(0)
+
+    num_qubits = 6  # Uses FWT path
+    data_len = num_qubits + num_qubits * (num_qubits - 1) // 2
+    data = [0.3 * i for i in range(data_len)]
+
+    # Run encoding twice
+    qtensor1 = engine.encode(data, num_qubits, "iqp")
+    tensor1 = torch.from_dlpack(qtensor1).clone()
+
+    qtensor2 = engine.encode(data, num_qubits, "iqp")
+    tensor2 = torch.from_dlpack(qtensor2)
+
+    # Results should be identical
+    assert torch.allclose(tensor1, tensor2, atol=1e-10), (
+        "IQP FWT encoding should be deterministic"
+    )
+
+
+@pytest.mark.gpu
+def test_iqp_fwt_shared_vs_global_memory_threshold():
+    """Test IQP encoding at shared memory threshold boundary (requires GPU).
+
+    FWT_SHARED_MEM_THRESHOLD = 10, so:
+    - n <= 10: uses shared memory FWT
+    - n > 10: uses global memory FWT
+    """
+    pytest.importorskip("torch")
+    import torch
+    from _qdp import QdpEngine
+
+    if not torch.cuda.is_available():
+        pytest.skip("GPU required for QdpEngine")
+
+    engine = QdpEngine(0)
+
+    # Test at and around the shared memory threshold
+    # n <= 10: shared memory FWT, n > 10: global memory FWT (multi-launch)
+    for num_qubits in [9, 10, 11]:
+        data_len = num_qubits + num_qubits * (num_qubits - 1) // 2
+        data = [0.05 * i for i in range(data_len)]
+
+        qtensor = engine.encode(data, num_qubits, "iqp")
+        torch_tensor = torch.from_dlpack(qtensor)
+
+        assert torch_tensor.shape == (1, 1 << num_qubits)
+
+        norm = torch.sum(torch.abs(torch_tensor) ** 2)
+        assert torch.isclose(norm, torch.tensor(1.0, device="cuda:0"), atol=1e-6), (
+            f"IQP {num_qubits} qubits not normalized at threshold: got {norm.item()}"
+        )
diff --git a/uv.lock b/uv.lock
index f1124a64d..0f1a55e10 100644
--- a/uv.lock
+++ b/uv.lock
@@ -1997,6 +1997,10 @@ benchmark = [
     { name = "torch", specifier = ">=2.2,<=2.9.0" },
     { name = "tqdm" },
 ]
+dev = [
+    { name = "pytest" },
+    { name = "torch", specifier = ">=2.2,<=2.9.0" },
+]
 
 [[package]]
 name = "requests"


Reply via email to