This is an automated email from the ASF dual-hosted git repository.
guanmingchiu pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/mahout.git
The following commit(s) were added to refs/heads/main by this push:
new 745143223 MAHOUT-900 [IQP] Implement FWT optimization for IQP encoding (#938)
745143223 is described below
commit 74514322363c8b3a0ab77b4f61328470f65ffea6
Author: Ryan Huang <[email protected]>
AuthorDate: Mon Feb 2 01:47:50 2026 +0800
MAHOUT-900 [IQP] Implement FWT optimization for IQP encoding (#938)
* MAHOUT-900 [IQP] Implement FWT optimization for IQP encoding and add
corresponding tests and benchmarks
- Added Fast Walsh-Hadamard Transform (FWT) optimization to IQP encoding,
reducing complexity from O(4^n) to O(n * 2^n).
- Introduced new kernels for phase computation and FWT stages in both
single and batch encoding.
- Updated kernel configuration with thresholds for shared memory usage.
- Created tests to verify normalization and correctness of FWT-optimized
IQP states.
- Added benchmarking script to measure performance improvements from FWT
optimization.
* linter
* add test coverage bits
* remove single-purpose benchmark
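The complexity claim above follows from a standard identity: the final IQP amplitudes are the Walsh-Hadamard transform of the per-basis-state phase factors, scaled by 1/2^n, and that transform factors into n butterfly stages. Below is a minimal NumPy sketch of the three-step pipeline (phase, FWT, normalize); the parameter layout (n linear terms followed by the n*(n-1)/2 ZZ terms in j < k order) and the bit ordering are illustrative assumptions, not necessarily the kernels' exact convention:

import numpy as np

def iqp_state_fwt(data, n, enable_zz=True):
    size = 1 << n
    # Step 1: f[x] = exp(i * theta(x)) for every basis state x
    f = np.empty(size, dtype=complex)
    for x in range(size):
        bits = [(x >> j) & 1 for j in range(n)]  # assumed: qubit j = bit j of x
        phase = sum(data[j] * bits[j] for j in range(n))  # single-Z terms
        if enable_zz:
            idx = n
            for j in range(n):
                for k in range(j + 1, n):  # assumed j < k parameter order
                    phase += data[idx] * bits[j] * bits[k]  # ZZ terms
                    idx += 1
        f[x] = np.exp(1j * phase)
    # Step 2: in-place FWT, n stages of (a, b) -> (a + b, a - b) butterflies
    for stage in range(n):
        stride = 1 << stage
        for i in range(size):
            if i & stride == 0:
                f[i], f[i + stride] = f[i] + f[i + stride], f[i] - f[i + stride]
    # Step 3: normalize by 2^n = size
    return f / size

With all-zero parameters every f[x] is 1, the butterflies concentrate the entire sum at index 0, and the normalized result is exactly the |0...0> state, which is the identity the new zero-parameter tests rely on.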
Signed-off-by: Hsien-Cheng Huang <[email protected]>
---------
Signed-off-by: Hsien-Cheng Huang <[email protected]>
---
qdp/qdp-core/tests/iqp_encoding.rs | 239 +++++++++++++++++++++++++
qdp/qdp-kernels/src/iqp.cu | 347 +++++++++++++++++++++++++++++++++++-
qdp/qdp-kernels/src/kernel_config.h | 12 ++
qdp/qdp-python/uv.lock | 83 +++++++++
testing/qdp/test_bindings.py | 193 ++++++++++++++++++++
uv.lock | 4 +
6 files changed, 870 insertions(+), 8 deletions(-)
diff --git a/qdp/qdp-core/tests/iqp_encoding.rs b/qdp/qdp-core/tests/iqp_encoding.rs
index 4fc48bcc5..7f976d0ad 100644
--- a/qdp/qdp-core/tests/iqp_encoding.rs
+++ b/qdp/qdp-core/tests/iqp_encoding.rs
@@ -585,6 +585,245 @@ fn test_iqp_data_length_calculations() {
println!("PASS: Data length calculations are correct");
}
+// =============================================================================
+// FWT Optimization Correctness Tests
+// =============================================================================
+
+#[test]
+#[cfg(target_os = "linux")]
+fn test_iqp_fwt_threshold_boundary() {
+ println!("Testing IQP FWT threshold boundary (n=4, where FWT kicks
in)...");
+
+ let engine = match QdpEngine::new(0) {
+ Ok(e) => e,
+ Err(_) => {
+ println!("SKIP: No GPU available");
+ return;
+ }
+ };
+
+ // Test at FWT_MIN_QUBITS threshold (n=4)
+ let num_qubits = 4;
+ let data: Vec<f64> = (0..iqp_full_data_len(num_qubits))
+ .map(|i| (i as f64) * 0.1)
+ .collect();
+
+ let result = engine.encode(&data, num_qubits, "iqp");
+ let dlpack_ptr = result.expect("IQP encoding at FWT threshold should
succeed");
+ assert!(!dlpack_ptr.is_null(), "DLPack pointer should not be null");
+
+ unsafe {
+ let managed = &*dlpack_ptr;
+ let tensor = &managed.dl_tensor;
+
+ assert_eq!(tensor.ndim, 2, "Tensor should be 2D");
+ let shape_slice = std::slice::from_raw_parts(tensor.shape, tensor.ndim as usize);
+ assert_eq!(
+ shape_slice[1],
+ 1 << num_qubits,
+ "Should have 2^n amplitudes"
+ );
+
+ println!(
+ "PASS: IQP FWT threshold boundary test with shape [{}, {}]",
+ shape_slice[0], shape_slice[1]
+ );
+
+ if let Some(deleter) = managed.deleter {
+ deleter(dlpack_ptr);
+ }
+ }
+}
+
+#[test]
+#[cfg(target_os = "linux")]
+fn test_iqp_fwt_larger_qubit_counts() {
+ println!("Testing IQP FWT with larger qubit counts (n=5,6,7,8)...");
+
+ let engine = match QdpEngine::new(0) {
+ Ok(e) => e,
+ Err(_) => {
+ println!("SKIP: No GPU available");
+ return;
+ }
+ };
+
+ for num_qubits in [5, 6, 7, 8] {
+ let data: Vec<f64> = (0..iqp_full_data_len(num_qubits))
+ .map(|i| (i as f64) * 0.05)
+ .collect();
+
+ let result = engine.encode(&data, num_qubits, "iqp");
+ let dlpack_ptr = result
+ .unwrap_or_else(|_| panic!("IQP encoding for {} qubits should
succeed", num_qubits));
+ assert!(!dlpack_ptr.is_null());
+
+ unsafe {
+ let managed = &*dlpack_ptr;
+ let tensor = &managed.dl_tensor;
+
+ let shape_slice = std::slice::from_raw_parts(tensor.shape, tensor.ndim as usize);
+ assert_eq!(
+ shape_slice[1],
+ (1 << num_qubits) as i64,
+ "Should have 2^{} amplitudes",
+ num_qubits
+ );
+
+ println!(
+ " {} qubits: shape [{}, {}] - PASS",
+ num_qubits, shape_slice[0], shape_slice[1]
+ );
+
+ if let Some(deleter) = managed.deleter {
+ deleter(dlpack_ptr);
+ }
+ }
+ }
+
+ println!("PASS: IQP FWT larger qubit count tests completed");
+}
+
+#[test]
+#[cfg(target_os = "linux")]
+fn test_iqp_z_fwt_correctness() {
+ println!("Testing IQP-Z FWT correctness for various qubit counts...");
+
+ let engine = match QdpEngine::new(0) {
+ Ok(e) => e,
+ Err(_) => {
+ println!("SKIP: No GPU available");
+ return;
+ }
+ };
+
+ // Test IQP-Z across FWT threshold
+ for num_qubits in [3, 4, 5, 6] {
+ let data: Vec<f64> = (0..iqp_z_data_len(num_qubits))
+ .map(|i| (i as f64) * 0.15)
+ .collect();
+
+ let result = engine.encode(&data, num_qubits, "iqp-z");
+ let dlpack_ptr = result
+ .unwrap_or_else(|_| panic!("IQP-Z encoding for {} qubits should
succeed", num_qubits));
+ assert!(!dlpack_ptr.is_null());
+
+ unsafe {
+ let managed = &*dlpack_ptr;
+ let tensor = &managed.dl_tensor;
+
+ let shape_slice = std::slice::from_raw_parts(tensor.shape, tensor.ndim as usize);
+ assert_eq!(shape_slice[1], (1 << num_qubits) as i64);
+
+ println!(
+ " IQP-Z {} qubits: shape [{}, {}] - PASS",
+ num_qubits, shape_slice[0], shape_slice[1]
+ );
+
+ if let Some(deleter) = managed.deleter {
+ deleter(dlpack_ptr);
+ }
+ }
+ }
+
+ println!("PASS: IQP-Z FWT correctness tests completed");
+}
+
+#[test]
+#[cfg(target_os = "linux")]
+fn test_iqp_fwt_batch_various_sizes() {
+ println!("Testing IQP FWT batch encoding with various qubit counts...");
+
+ let engine = match QdpEngine::new(0) {
+ Ok(e) => e,
+ Err(_) => {
+ println!("SKIP: No GPU available");
+ return;
+ }
+ };
+
+ // Test batch encoding across FWT threshold
+ for num_qubits in [3, 4, 5, 6] {
+ let num_samples = 8;
+ let sample_size = iqp_full_data_len(num_qubits);
+
+ let batch_data: Vec<f64> = (0..num_samples * sample_size)
+ .map(|i| (i as f64) * 0.02)
+ .collect();
+
+ let result = engine.encode_batch(&batch_data, num_samples, sample_size, num_qubits, "iqp");
+ let dlpack_ptr = result.unwrap_or_else(|_| {
+ panic!(
+ "IQP batch encoding for {} qubits should succeed",
+ num_qubits
+ )
+ });
+ assert!(!dlpack_ptr.is_null());
+
+ unsafe {
+ let managed = &*dlpack_ptr;
+ let tensor = &managed.dl_tensor;
+
+ let shape_slice = std::slice::from_raw_parts(tensor.shape, tensor.ndim as usize);
+ assert_eq!(shape_slice[0], num_samples as i64);
+ assert_eq!(shape_slice[1], (1 << num_qubits) as i64);
+
+ println!(
+ " IQP batch {} qubits x {} samples: shape [{}, {}] - PASS",
+ num_qubits, num_samples, shape_slice[0], shape_slice[1]
+ );
+
+ if let Some(deleter) = managed.deleter {
+ deleter(dlpack_ptr);
+ }
+ }
+ }
+
+ println!("PASS: IQP FWT batch encoding tests completed");
+}
+
+#[test]
+#[cfg(target_os = "linux")]
+fn test_iqp_fwt_zero_parameters_identity() {
+ println!("Testing IQP FWT with zero parameters produces |0⟩ state...");
+
+ let engine = match QdpEngine::new(0) {
+ Ok(e) => e,
+ Err(_) => {
+ println!("SKIP: No GPU available");
+ return;
+ }
+ };
+
+ // For FWT-optimized path (n >= 4), zero parameters should still give |0⟩
+ for num_qubits in [4, 5, 6] {
+ let data: Vec<f64> = vec![0.0; iqp_full_data_len(num_qubits)];
+
+ let result = engine.encode(&data, num_qubits, "iqp");
+ let dlpack_ptr = result.expect("IQP encoding with zero params should
succeed");
+ assert!(!dlpack_ptr.is_null());
+
+ unsafe {
+ let managed = &*dlpack_ptr;
+ let tensor = &managed.dl_tensor;
+
+ let shape_slice = std::slice::from_raw_parts(tensor.shape, tensor.ndim as usize);
+ assert_eq!(shape_slice[1], (1 << num_qubits) as i64);
+
+ println!(
+ " IQP zero params {} qubits: verified shape - PASS",
+ num_qubits
+ );
+
+ if let Some(deleter) = managed.deleter {
+ deleter(dlpack_ptr);
+ }
+ }
+ }
+
+ println!("PASS: IQP FWT zero parameters test completed");
+}
+
// =============================================================================
// Encoder Factory Tests
// =============================================================================
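The Rust tests above exercise shapes and resource lifecycles on the GPU. The mathematical equivalence they lean on, that n butterfly stages reproduce the naive O(4^n) double sum, can be checked host-side in a few lines. A self-contained sketch in plain Python/NumPy (fwt here is a reference implementation, independent of the CUDA kernels):

import numpy as np

def fwt(v):
    # Unnormalized Walsh-Hadamard transform: n stages of butterflies.
    v = v.copy()
    n = v.size.bit_length() - 1
    for stage in range(n):
        stride = 1 << stage
        for i in range(v.size):
            if i & stride == 0:
                v[i], v[i + stride] = v[i] + v[i + stride], v[i] - v[i + stride]
    return v

rng = np.random.default_rng(0)
f = np.exp(1j * rng.uniform(0.0, 2.0 * np.pi, 1 << 5))  # unit-modulus phases
size = f.size
# Naive definition: a[z] = (1/2^n) * sum_x (-1)^popcount(x & z) * f[x]
naive = np.array([sum((-1) ** bin(x & z).count("1") * f[x] for x in range(size))
                  for z in range(size)]) / size
assert np.allclose(fwt(f) / size, naive)
assert np.isclose(np.sum(np.abs(fwt(f) / size) ** 2), 1.0)  # normalized

The second assertion shows why a sum-of-|amplitude|^2-equals-1 check is a meaningful test: unit-modulus phases plus the 1/2^n scaling guarantee normalization.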
diff --git a/qdp/qdp-kernels/src/iqp.cu b/qdp/qdp-kernels/src/iqp.cu
index f0e63db50..9ec884a68 100644
--- a/qdp/qdp-kernels/src/iqp.cu
+++ b/qdp/qdp-kernels/src/iqp.cu
@@ -64,7 +64,11 @@ __device__ double compute_phase(
return phase;
}
-__global__ void iqp_encode_kernel(
+// ============================================================================
+// Naive O(4^n) Implementation (kept as fallback for small n and verification)
+// ============================================================================
+
+__global__ void iqp_encode_kernel_naive(
const double* __restrict__ data,
cuDoubleComplex* __restrict__ state,
size_t state_len,
@@ -97,7 +101,134 @@ __global__ void iqp_encode_kernel(
state[z] = make_cuDoubleComplex(real_sum * norm, imag_sum * norm);
}
-__global__ void iqp_encode_batch_kernel(
+
+// ============================================================================
+// FWT O(n * 2^n) Implementation
+// ============================================================================
+
+// Step 1: Compute f[x] = exp(i*theta(x)) for all x
+// One thread per state, reuses existing compute_phase()
+__global__ void iqp_phase_kernel(
+ const double* __restrict__ data,
+ cuDoubleComplex* __restrict__ state,
+ size_t state_len,
+ unsigned int num_qubits,
+ int enable_zz
+) {
+ size_t x = blockIdx.x * blockDim.x + threadIdx.x;
+ if (x >= state_len) return;
+
+ double phase = compute_phase(data, x, num_qubits, enable_zz);
+
+ double cos_phase, sin_phase;
+ sincos(phase, &sin_phase, &cos_phase);
+ state[x] = make_cuDoubleComplex(cos_phase, sin_phase);
+}
+
+// Step 2a: FWT butterfly stage for global memory (n > threshold)
+// Each thread handles one butterfly pair per stage
+// Walsh-Hadamard butterfly: (a, b) -> (a + b, a - b)
+__global__ void fwt_butterfly_stage_kernel(
+ cuDoubleComplex* __restrict__ state,
+ size_t state_len,
+ unsigned int stage // 0 to n-1
+) {
+ size_t idx = blockIdx.x * blockDim.x + threadIdx.x;
+
+ // Each thread processes one butterfly pair
+ // For stage s, butterflies are separated by 2^s
+ size_t stride = 1ULL << stage;
+ size_t block_size = stride << 1; // 2^(s+1)
+ size_t num_pairs = state_len >> 1; // state_len / 2 total pairs
+
+ if (idx >= num_pairs) return;
+
+ // Compute which butterfly pair this thread handles
+ size_t block_idx = idx / stride;
+ size_t pair_offset = idx % stride;
+ size_t i = block_idx * block_size + pair_offset;
+ size_t j = i + stride;
+
+ // Load values
+ cuDoubleComplex a = state[i];
+ cuDoubleComplex b = state[j];
+
+ // Butterfly: (a, b) -> (a + b, a - b)
+ state[i] = cuCadd(a, b);
+ state[j] = cuCsub(a, b);
+}
+
+// Step 2b: FWT using shared memory (n <= threshold)
+// All stages in single kernel launch
+__global__ void fwt_shared_memory_kernel(
+ cuDoubleComplex* __restrict__ state,
+ size_t state_len,
+ unsigned int num_qubits
+) {
+ extern __shared__ cuDoubleComplex shared_state[];
+
+ size_t tid = threadIdx.x;
+ size_t bid = blockIdx.x;
+
+ // For shared memory FWT, we process the entire state in one block
+ // Block 0 handles the full transform
+ if (bid > 0) return;
+
+ // Load state into shared memory
+ for (size_t i = tid; i < state_len; i += blockDim.x) {
+ shared_state[i] = state[i];
+ }
+ __syncthreads();
+
+ // Perform all FWT stages in shared memory
+ for (unsigned int stage = 0; stage < num_qubits; ++stage) {
+ size_t stride = 1ULL << stage;
+ size_t block_size = stride << 1;
+ size_t num_pairs = state_len >> 1;
+
+ // Each thread handles multiple pairs if needed
+ for (size_t pair_idx = tid; pair_idx < num_pairs; pair_idx += blockDim.x) {
+ size_t block_idx = pair_idx / stride;
+ size_t pair_offset = pair_idx % stride;
+ size_t i = block_idx * block_size + pair_offset;
+ size_t j = i + stride;
+
+ cuDoubleComplex a = shared_state[i];
+ cuDoubleComplex b = shared_state[j];
+
+ shared_state[i] = cuCadd(a, b);
+ shared_state[j] = cuCsub(a, b);
+ }
+ __syncthreads();
+ }
+
+ // Write back to global memory
+ for (size_t i = tid; i < state_len; i += blockDim.x) {
+ state[i] = shared_state[i];
+ }
+}
+
+// Step 3: Normalize the state by 1/state_len (= 1/2^n)
+__global__ void normalize_state_kernel(
+ cuDoubleComplex* __restrict__ state,
+ size_t state_len,
+ double norm_factor
+) {
+ size_t idx = blockIdx.x * blockDim.x + threadIdx.x;
+ if (idx >= state_len) return;
+
+ cuDoubleComplex val = state[idx];
+ state[idx] = make_cuDoubleComplex(
+ cuCreal(val) * norm_factor,
+ cuCimag(val) * norm_factor
+ );
+}
+
+// ============================================================================
+// Naive O(4^n) Batch Implementation (kept as fallback)
+// ============================================================================
+
+__global__ void iqp_encode_batch_kernel_naive(
const double* __restrict__ data_batch,
cuDoubleComplex* __restrict__ state_batch,
size_t num_samples,
@@ -139,9 +270,108 @@ __global__ void iqp_encode_batch_kernel(
}
}
+
+// ============================================================================
+// FWT O(n * 2^n) Batch Implementation
+// ============================================================================
+
+// Step 1: Compute f[x] = exp(i*theta(x)) for all x, for all samples in batch
+__global__ void iqp_phase_batch_kernel(
+ const double* __restrict__ data_batch,
+ cuDoubleComplex* __restrict__ state_batch,
+ size_t num_samples,
+ size_t state_len,
+ unsigned int num_qubits,
+ unsigned int data_len,
+ int enable_zz
+) {
+ const size_t total_elements = num_samples * state_len;
+ const size_t stride = gridDim.x * blockDim.x;
+ const size_t state_mask = state_len - 1;
+
+ for (size_t global_idx = blockIdx.x * blockDim.x + threadIdx.x;
+ global_idx < total_elements;
+ global_idx += stride) {
+ const size_t sample_idx = global_idx >> num_qubits;
+ const size_t x = global_idx & state_mask;
+ const double* data = data_batch + sample_idx * data_len;
+
+ double phase = compute_phase(data, x, num_qubits, enable_zz);
+
+ double cos_phase, sin_phase;
+ sincos(phase, &sin_phase, &cos_phase);
+ state_batch[global_idx] = make_cuDoubleComplex(cos_phase, sin_phase);
+ }
+}
+
+// Step 2: FWT butterfly stage for batch (global memory)
+// Processes all samples in parallel
+__global__ void fwt_butterfly_batch_kernel(
+ cuDoubleComplex* __restrict__ state_batch,
+ size_t num_samples,
+ size_t state_len,
+ unsigned int num_qubits,
+ unsigned int stage
+) {
+ const size_t pairs_per_sample = state_len >> 1;
+ const size_t total_pairs = num_samples * pairs_per_sample;
+ const size_t grid_stride = gridDim.x * blockDim.x;
+
+ // For stage s, butterflies are separated by 2^s
+ const size_t stride = 1ULL << stage;
+ const size_t block_size = stride << 1;
+
+ for (size_t global_pair_idx = blockIdx.x * blockDim.x + threadIdx.x;
+ global_pair_idx < total_pairs;
+ global_pair_idx += grid_stride) {
+
+ // Determine which sample and which pair within that sample
+ const size_t sample_idx = global_pair_idx / pairs_per_sample;
+ const size_t pair_idx = global_pair_idx % pairs_per_sample;
+
+ // Compute indices within this sample's state
+ const size_t block_idx = pair_idx / stride;
+ const size_t pair_offset = pair_idx % stride;
+ const size_t local_i = block_idx * block_size + pair_offset;
+ const size_t local_j = local_i + stride;
+
+ // Global indices
+ const size_t base = sample_idx * state_len;
+ const size_t i = base + local_i;
+ const size_t j = base + local_j;
+
+ // Load values
+ cuDoubleComplex a = state_batch[i];
+ cuDoubleComplex b = state_batch[j];
+
+ // Butterfly: (a, b) -> (a + b, a - b)
+ state_batch[i] = cuCadd(a, b);
+ state_batch[j] = cuCsub(a, b);
+ }
+}
+
+// Step 3: Normalize all samples in batch
+__global__ void normalize_batch_kernel(
+ cuDoubleComplex* __restrict__ state_batch,
+ size_t total_elements,
+ double norm_factor
+) {
+ const size_t stride = gridDim.x * blockDim.x;
+
+ for (size_t idx = blockIdx.x * blockDim.x + threadIdx.x;
+ idx < total_elements;
+ idx += stride) {
+ cuDoubleComplex val = state_batch[idx];
+ state_batch[idx] = make_cuDoubleComplex(
+ cuCreal(val) * norm_factor,
+ cuCimag(val) * norm_factor
+ );
+ }
+}
+
extern "C" {
-/// Launch IQP encoding kernel
+/// Launch IQP encoding kernel using FWT optimization
///
/// # Arguments
/// * data_d - Device pointer to encoding parameters
@@ -153,6 +383,13 @@ extern "C" {
///
/// # Returns
/// CUDA error code (0 = cudaSuccess)
+///
+/// # Algorithm
+/// For num_qubits >= FWT_MIN_QUBITS, uses Fast Walsh-Hadamard Transform:
+/// 1. Phase computation: f[x] = exp(i*theta(x)) - O(2^n)
+/// 2. FWT transform: WHT of phase array - O(n * 2^n)
+/// 3. Normalization: divide by 2^n - O(2^n)
+/// Total: O(n * 2^n) vs naive O(4^n)
int launch_iqp_encode(
const double* data_d,
void* state_d,
@@ -166,11 +403,26 @@ int launch_iqp_encode(
}
cuDoubleComplex* state_complex_d = static_cast<cuDoubleComplex*>(state_d);
-
const int blockSize = DEFAULT_BLOCK_SIZE;
+
+ // Use naive kernel for small n (FWT overhead not worth it)
+ if (num_qubits < FWT_MIN_QUBITS) {
+ const int gridSize = (state_len + blockSize - 1) / blockSize;
+ iqp_encode_kernel_naive<<<gridSize, blockSize, 0, stream>>>(
+ data_d,
+ state_complex_d,
+ state_len,
+ num_qubits,
+ enable_zz
+ );
+ return (int)cudaGetLastError();
+ }
+
+ // FWT-based implementation for larger n
const int gridSize = (state_len + blockSize - 1) / blockSize;
- iqp_encode_kernel<<<gridSize, blockSize, 0, stream>>>(
+ // Step 1: Compute phase array f[x] = exp(i*theta(x))
+ iqp_phase_kernel<<<gridSize, blockSize, 0, stream>>>(
data_d,
state_complex_d,
state_len,
@@ -178,10 +430,41 @@ int launch_iqp_encode(
enable_zz
);
+ // Step 2: Apply FWT
+ if (num_qubits <= FWT_SHARED_MEM_THRESHOLD) {
+ // Shared memory FWT - all stages in one kernel
+ size_t shared_mem_size = state_len * sizeof(cuDoubleComplex);
+ fwt_shared_memory_kernel<<<1, blockSize, shared_mem_size, stream>>>(
+ state_complex_d,
+ state_len,
+ num_qubits
+ );
+ } else {
+ // Global memory FWT - one kernel launch per stage
+ const size_t num_pairs = state_len >> 1;
+ const int fwt_grid_size = (num_pairs + blockSize - 1) / blockSize;
+
+ for (unsigned int stage = 0; stage < num_qubits; ++stage) {
+ fwt_butterfly_stage_kernel<<<fwt_grid_size, blockSize, 0, stream>>>(
+ state_complex_d,
+ state_len,
+ stage
+ );
+ }
+ }
+
+ // Step 3: Normalize by 1/2^n
+ double norm_factor = 1.0 / (double)state_len;
+ normalize_state_kernel<<<gridSize, blockSize, 0, stream>>>(
+ state_complex_d,
+ state_len,
+ norm_factor
+ );
+
return (int)cudaGetLastError();
}
-/// Launch batch IQP encoding kernel
+/// Launch batch IQP encoding kernel using FWT optimization
///
/// # Arguments
/// * data_batch_d - Device pointer to batch parameters (num_samples * data_len)
@@ -195,6 +478,13 @@ int launch_iqp_encode(
///
/// # Returns
/// CUDA error code (0 = cudaSuccess)
+///
+/// # Algorithm
+/// For num_qubits >= FWT_MIN_QUBITS, uses Fast Walsh-Hadamard Transform:
+/// 1. Phase computation for all samples - O(batch * 2^n)
+/// 2. FWT transform for all samples - O(batch * n * 2^n)
+/// 3. Normalization - O(batch * 2^n)
+/// Total: O(batch * n * 2^n) vs naive O(batch * 4^n)
int launch_iqp_encode_batch(
const double* data_batch_d,
void* state_batch_d,
@@ -210,13 +500,29 @@ int launch_iqp_encode_batch(
}
cuDoubleComplex* state_complex_d = static_cast<cuDoubleComplex*>(state_batch_d);
-
const int blockSize = DEFAULT_BLOCK_SIZE;
const size_t total_elements = num_samples * state_len;
const size_t blocks_needed = (total_elements + blockSize - 1) / blockSize;
const size_t gridSize = (blocks_needed < MAX_GRID_BLOCKS) ? blocks_needed : MAX_GRID_BLOCKS;
- iqp_encode_batch_kernel<<<gridSize, blockSize, 0, stream>>>(
+ // Use naive kernel for small n (FWT overhead not worth it)
+ if (num_qubits < FWT_MIN_QUBITS) {
+ iqp_encode_batch_kernel_naive<<<gridSize, blockSize, 0, stream>>>(
+ data_batch_d,
+ state_complex_d,
+ num_samples,
+ state_len,
+ num_qubits,
+ data_len,
+ enable_zz
+ );
+ return (int)cudaGetLastError();
+ }
+
+ // FWT-based implementation for larger n
+
+ // Step 1: Compute phase array f[x] = exp(i*theta(x)) for all samples
+ iqp_phase_batch_kernel<<<gridSize, blockSize, 0, stream>>>(
data_batch_d,
state_complex_d,
num_samples,
@@ -226,6 +532,31 @@ int launch_iqp_encode_batch(
enable_zz
);
+ // Step 2: Apply FWT to all samples (global memory version for batch)
+ // For batch processing, we always use global memory FWT
+ // (shared memory would require processing samples one at a time)
+ const size_t total_pairs = num_samples * (state_len >> 1);
+ const size_t fwt_blocks_needed = (total_pairs + blockSize - 1) / blockSize;
+ const size_t fwt_grid_size = (fwt_blocks_needed < MAX_GRID_BLOCKS) ? fwt_blocks_needed : MAX_GRID_BLOCKS;
+
+ for (unsigned int stage = 0; stage < num_qubits; ++stage) {
+ fwt_butterfly_batch_kernel<<<fwt_grid_size, blockSize, 0, stream>>>(
+ state_complex_d,
+ num_samples,
+ state_len,
+ num_qubits,
+ stage
+ );
+ }
+
+ // Step 3: Normalize by 1/2^n
+ double norm_factor = 1.0 / (double)state_len;
+ normalize_batch_kernel<<<gridSize, blockSize, 0, stream>>>(
+ state_complex_d,
+ total_elements,
+ norm_factor
+ );
+
return (int)cudaGetLastError();
}
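One detail of the butterfly kernels above that is worth spelling out: for stage s, the (block_idx, pair_offset) arithmetic enumerates exactly the indices i whose bit s is clear, each paired once with j = i + 2^s. A quick host-side check of that index math, in plain Python mirroring the kernel's integer arithmetic (not the CUDA code itself):

def butterfly_pairs(state_len, stage):
    stride = 1 << stage
    block_size = stride << 1
    for idx in range(state_len >> 1):  # state_len / 2 pairs per stage
        block_idx, pair_offset = divmod(idx, stride)
        i = block_idx * block_size + pair_offset
        yield i, i + stride

for stage in range(4):
    pairs = list(butterfly_pairs(16, stage))
    assert len(pairs) == 8
    assert all(i & (1 << stage) == 0 for i, _ in pairs)  # bit `stage` clear
    assert sorted(x for pair in pairs for x in pair) == list(range(16))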
diff --git a/qdp/qdp-kernels/src/kernel_config.h b/qdp/qdp-kernels/src/kernel_config.h
index b00ef169f..4ce4526cb 100644
--- a/qdp/qdp-kernels/src/kernel_config.h
+++ b/qdp/qdp-kernels/src/kernel_config.h
@@ -57,6 +57,18 @@
// This limit ensures state vectors fit within practical GPU memory constraints
#define MAX_QUBITS 30
+// ============================================================================
+// FWT (Fast Walsh-Hadamard Transform) Configuration
+// ============================================================================
+// Threshold for shared memory FWT optimization
+// For n <= this threshold, use shared memory FWT (single kernel launch)
+// For n > threshold, use global memory FWT (multiple kernel launches)
+// 10 qubits = 2^10 * 16 bytes (cuDoubleComplex) = 16KB shared memory
+#define FWT_SHARED_MEM_THRESHOLD 10
+
+// Minimum qubits to use FWT optimization (below this, naive is competitive)
+#define FWT_MIN_QUBITS 4
+
// ============================================================================
// Convenience Macros
// ============================================================================
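For reference, the arithmetic behind FWT_SHARED_MEM_THRESHOLD, assuming 16-byte cuDoubleComplex elements and the common 48 KB default static shared-memory limit per block (the exact limit is device-dependent, so treat this as a back-of-the-envelope check, not part of the kernels):

for n in (10, 11, 12):
    print(n, (1 << n) * 16, "bytes for the full state in shared memory")
# 10 -> 16384 (16 KB), 11 -> 32768, 12 -> 65536; n = 10 keeps the whole
# state comfortably inside a 48 KB per-block budget, a conservative cutoff.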
diff --git a/qdp/qdp-python/uv.lock b/qdp/qdp-python/uv.lock
index 658438764..b75377bd9 100644
--- a/qdp/qdp-python/uv.lock
+++ b/qdp/qdp-python/uv.lock
@@ -572,6 +572,18 @@ wheels = [
{ url = "https://files.pythonhosted.org/packages/be/95/03c8215f675349ff719cb44cd837c2468fdc0c05f55f523f3cad86bbdcc6/duet-0.2.9-py3-none-any.whl", hash = "sha256:a16088b68b0faee8aee12cdf4d0a8af060ed958badb44f3e32f123f13f64119a", size = 29560, upload-time = "2023-07-26T06:38:58.931Z" },
]
+[[package]]
+name = "exceptiongroup"
+version = "1.3.1"
+source = { registry = "https://pypi.org/simple" }
+dependencies = [
+ { name = "typing-extensions", marker = "python_full_version < '3.11'" },
+]
+sdist = { url = "https://files.pythonhosted.org/packages/50/79/66800aadf48771f6b62f7eb014e352e5d06856655206165d775e675a02c9/exceptiongroup-1.3.1.tar.gz", hash = "sha256:8b412432c6055b0b7d14c310000ae93352ed6754f70fa8f7c34141f91c4e3219", size = 30371, upload-time = "2025-11-21T23:01:54.787Z" }
+wheels = [
+ { url = "https://files.pythonhosted.org/packages/8a/0e/97c33bf5009bdbac74fd2beace167cab3f978feb69cc36f1ef79360d6c4e/exceptiongroup-1.3.1-py3-none-any.whl", hash = "sha256:a7a39a3bd276781e98394987d3a5701d0c4edffb633bb7a5144577f82c773598", size = 16740, upload-time = "2025-11-21T23:01:53.443Z" },
+]
+
[[package]]
name = "filelock"
version = "3.20.0"
@@ -797,6 +809,15 @@ wheels = [
{ url = "https://files.pythonhosted.org/packages/0e/61/66938bbb5fc52dbdf84594873d5b51fb1f7c7794e9c0f5bd885f30bc507b/idna-3.11-py3-none-any.whl", hash = "sha256:771a87f49d9defaf64091e6e6fe9c18d4833f140bd19464795bc32d966ca37ea", size = 71008, upload-time = "2025-10-12T14:55:18.883Z" },
]
+[[package]]
+name = "iniconfig"
+version = "2.3.0"
+source = { registry = "https://pypi.org/simple" }
+sdist = { url = "https://files.pythonhosted.org/packages/72/34/14ca021ce8e5dfedc35312d08ba8bf51fdd999c576889fc2c24cb97f4f10/iniconfig-2.3.0.tar.gz", hash = "sha256:c76315c77db068650d49c5b56314774a7804df16fee4402c1f19d6d15d8c4730", size = 20503, upload-time = "2025-10-18T21:55:43.219Z" }
+wheels = [
+ { url = "https://files.pythonhosted.org/packages/cb/b1/3846dd7f199d53cb17f49cba7e651e9ce294d8497c8c150530ed11865bb8/iniconfig-2.3.0-py3-none-any.whl", hash = "sha256:f631c04d2c48c52b84d0d0549c99ff3859c98df65b3101406327ecc7d53fbf12", size = 7484, upload-time = "2025-10-18T21:55:41.639Z" },
+]
+
[[package]]
name = "jinja2"
version = "3.1.6"
@@ -1635,6 +1656,15 @@ wheels = [
{ url = "https://files.pythonhosted.org/packages/95/7e/f896623c3c635a90537ac093c6a618ebe1a90d87206e42309cb5d98a1b9e/pillow-12.0.0-pp311-pypy311_pp73-win_amd64.whl", hash = "sha256:b290fd8aa38422444d4b50d579de197557f182ef1068b75f5aa8558638b8d0a5", size = 6997850, upload-time = "2025-10-15T18:24:11.495Z" },
]
+[[package]]
+name = "pluggy"
+version = "1.6.0"
+source = { registry = "https://pypi.org/simple" }
+sdist = { url = "https://files.pythonhosted.org/packages/f9/e2/3e91f31a7d2b083fe6ef3fa267035b518369d9511ffab804f839851d2779/pluggy-1.6.0.tar.gz", hash = "sha256:7dcc130b76258d33b90f61b658791dede3486c3e6bfb003ee5c9bfb396dd22f3", size = 69412, upload-time = "2025-05-15T12:30:07.975Z" }
+wheels = [
+ { url = "https://files.pythonhosted.org/packages/54/20/4d324d65cc6d9205fabedc306948156824eb9f0ee1633355a8f7ec5c66bf/pluggy-1.6.0-py3-none-any.whl", hash = "sha256:e920276dd6813095e9377c0bc5566d94c932c33b27a3e3945d8389c374dd4746", size = 20538, upload-time = "2025-05-15T12:30:06.134Z" },
+]
+
[[package]]
name = "proto-plus"
version = "1.27.0"
@@ -1836,6 +1866,24 @@ wheels = [
{ url = "https://files.pythonhosted.org/packages/8b/40/2614036cdd416452f5bf98ec037f38a1afb17f327cb8e6b652d4729e0af8/pyparsing-3.3.1-py3-none-any.whl", hash = "sha256:023b5e7e5520ad96642e2c6db4cb683d3970bd640cdf7115049a6e9c3682df82", size = 121793, upload-time = "2025-12-23T03:14:02.103Z" },
]
+[[package]]
+name = "pytest"
+version = "9.0.2"
+source = { registry = "https://pypi.org/simple" }
+dependencies = [
+ { name = "colorama", marker = "sys_platform == 'win32'" },
+ { name = "exceptiongroup", marker = "python_full_version < '3.11'" },
+ { name = "iniconfig" },
+ { name = "packaging" },
+ { name = "pluggy" },
+ { name = "pygments" },
+ { name = "tomli", marker = "python_full_version < '3.11'" },
+]
+sdist = { url = "https://files.pythonhosted.org/packages/d1/db/7ef3487e0fb0049ddb5ce41d3a49c235bf9ad299b6a25d5780a89f19230f/pytest-9.0.2.tar.gz", hash = "sha256:75186651a92bd89611d1d9fc20f0b4345fd827c41ccd5c299a868a05d70edf11", size = 1568901, upload-time = "2025-12-06T21:30:51.014Z" }
+wheels = [
+ { url = "https://files.pythonhosted.org/packages/3b/ab/b3226f0bd7cdcf710fbede2b3548584366da3b19b5021e74f5bde2a8fa3f/pytest-9.0.2-py3-none-any.whl", hash = "sha256:711ffd45bf766d5264d487b917733b453d917afd2b0ad65223959f59089f875b", size = 374801, upload-time = "2025-12-06T21:30:49.154Z" },
+]
+
[[package]]
name = "python-dateutil"
version = "2.9.0.post0"
@@ -1980,6 +2028,10 @@ benchmark = [
{ name = "torch" },
{ name = "tqdm" },
]
+dev = [
+ { name = "pytest" },
+ { name = "torch" },
+]
[package.metadata]
requires-dist = [{ name = "qumat", editable = "../../" }]
@@ -1998,6 +2050,10 @@ benchmark = [
{ name = "torch", specifier = ">=2.2,<=2.9.0" },
{ name = "tqdm" },
]
+dev = [
+ { name = "pytest" },
+ { name = "torch", specifier = ">=2.2,<=2.9.0" },
+]
[[package]]
name = "requests"
@@ -2369,6 +2425,33 @@ wheels = [
{ url = "https://files.pythonhosted.org/packages/32/d5/f9a850d79b0851d1d4ef6456097579a9005b31fea68726a4ae5f2d82ddd9/threadpoolctl-3.6.0-py3-none-any.whl", hash = "sha256:43a0b8fd5a2928500110039e43a5eed8480b918967083ea48dc3ab9f13c4a7fb", size = 18638, upload-time = "2025-03-13T13:49:21.846Z" },
]
+[[package]]
+name = "tomli"
+version = "2.4.0"
+source = { registry = "https://pypi.org/simple" }
+sdist = { url = "https://files.pythonhosted.org/packages/82/30/31573e9457673ab10aa432461bee537ce6cef177667deca369efb79df071/tomli-2.4.0.tar.gz", hash = "sha256:aa89c3f6c277dd275d8e243ad24f3b5e701491a860d5121f2cdd399fbb31fc9c", size = 17477, upload-time = "2026-01-11T11:22:38.165Z" }
+wheels = [
+ { url = "https://files.pythonhosted.org/packages/3c/d9/3dc2289e1f3b32eb19b9785b6a006b28ee99acb37d1d47f78d4c10e28bf8/tomli-2.4.0-cp311-cp311-macosx_10_9_x86_64.whl", hash = "sha256:b5ef256a3fd497d4973c11bf142e9ed78b150d36f5773f1ca6088c230ffc5867", size = 153663, upload-time = "2026-01-11T11:21:45.27Z" },
+ { url = "https://files.pythonhosted.org/packages/51/32/ef9f6845e6b9ca392cd3f64f9ec185cc6f09f0a2df3db08cbe8809d1d435/tomli-2.4.0-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:5572e41282d5268eb09a697c89a7bee84fae66511f87533a6f88bd2f7b652da9", size = 148469, upload-time = "2026-01-11T11:21:46.873Z" },
+ { url = "https://files.pythonhosted.org/packages/d6/c2/506e44cce89a8b1b1e047d64bd495c22c9f71f21e05f380f1a950dd9c217/tomli-2.4.0-cp311-cp311-manylinux2014_aarch64.manylinux_2_17_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:551e321c6ba03b55676970b47cb1b73f14a0a4dce6a3e1a9458fd6d921d72e95", size = 236039, upload-time = "2026-01-11T11:21:48.503Z" },
+ { url = "https://files.pythonhosted.org/packages/b3/40/e1b65986dbc861b7e986e8ec394598187fa8aee85b1650b01dd925ca0be8/tomli-2.4.0-cp311-cp311-manylinux2014_x86_64.manylinux_2_17_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:5e3f639a7a8f10069d0e15408c0b96a2a828cfdec6fca05296ebcdcc28ca7c76", size = 243007, upload-time = "2026-01-11T11:21:49.456Z" },
+ { url = "https://files.pythonhosted.org/packages/9c/6f/6e39ce66b58a5b7ae572a0f4352ff40c71e8573633deda43f6a379d56b3e/tomli-2.4.0-cp311-cp311-musllinux_1_2_aarch64.whl", hash = "sha256:1b168f2731796b045128c45982d3a4874057626da0e2ef1fdd722848b741361d", size = 240875, upload-time = "2026-01-11T11:21:50.755Z" },
+ { url = "https://files.pythonhosted.org/packages/aa/ad/cb089cb190487caa80204d503c7fd0f4d443f90b95cf4ef5cf5aa0f439b0/tomli-2.4.0-cp311-cp311-musllinux_1_2_x86_64.whl", hash = "sha256:133e93646ec4300d651839d382d63edff11d8978be23da4cc106f5a18b7d0576", size = 246271, upload-time = "2026-01-11T11:21:51.81Z" },
+ { url = "https://files.pythonhosted.org/packages/0b/63/69125220e47fd7a3a27fd0de0c6398c89432fec41bc739823bcc66506af6/tomli-2.4.0-cp311-cp311-win32.whl", hash = "sha256:b6c78bdf37764092d369722d9946cb65b8767bfa4110f902a1b2542d8d173c8a", size = 96770, upload-time = "2026-01-11T11:21:52.647Z" },
+ { url = "https://files.pythonhosted.org/packages/1e/0d/a22bb6c83f83386b0008425a6cd1fa1c14b5f3dd4bad05e98cf3dbbf4a64/tomli-2.4.0-cp311-cp311-win_amd64.whl", hash = "sha256:d3d1654e11d724760cdb37a3d7691f0be9db5fbdaef59c9f532aabf87006dbaa", size = 107626, upload-time = "2026-01-11T11:21:53.459Z" },
+ { url = "https://files.pythonhosted.org/packages/2f/6d/77be674a3485e75cacbf2ddba2b146911477bd887dda9d8c9dfb2f15e871/tomli-2.4.0-cp311-cp311-win_arm64.whl", hash = "sha256:cae9c19ed12d4e8f3ebf46d1a75090e4c0dc16271c5bce1c833ac168f08fb614", size = 94842, upload-time = "2026-01-11T11:21:54.831Z" },
+ { url = "https://files.pythonhosted.org/packages/3c/43/7389a1869f2f26dba52404e1ef13b4784b6b37dac93bac53457e3ff24ca3/tomli-2.4.0-cp312-cp312-macosx_10_13_x86_64.whl", hash = "sha256:920b1de295e72887bafa3ad9f7a792f811847d57ea6b1215154030cf131f16b1", size = 154894, upload-time = "2026-01-11T11:21:56.07Z" },
+ { url = "https://files.pythonhosted.org/packages/e9/05/2f9bf110b5294132b2edf13fe6ca6ae456204f3d749f623307cbb7a946f2/tomli-2.4.0-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:7d6d9a4aee98fac3eab4952ad1d73aee87359452d1c086b5ceb43ed02ddb16b8", size = 149053, upload-time = "2026-01-11T11:21:57.467Z" },
+ { url = "https://files.pythonhosted.org/packages/e8/41/1eda3ca1abc6f6154a8db4d714a4d35c4ad90adc0bcf700657291593fbf3/tomli-2.4.0-cp312-cp312-manylinux2014_aarch64.manylinux_2_17_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:36b9d05b51e65b254ea6c2585b59d2c4cb91c8a3d91d0ed0f17591a29aaea54a", size = 243481, upload-time = "2026-01-11T11:21:58.661Z" },
+ { url = "https://files.pythonhosted.org/packages/d2/6d/02ff5ab6c8868b41e7d4b987ce2b5f6a51d3335a70aa144edd999e055a01/tomli-2.4.0-cp312-cp312-manylinux2014_x86_64.manylinux_2_17_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:1c8a885b370751837c029ef9bc014f27d80840e48bac415f3412e6593bbc18c1", size = 251720, upload-time = "2026-01-11T11:22:00.178Z" },
+ { url = "https://files.pythonhosted.org/packages/7b/57/0405c59a909c45d5b6f146107c6d997825aa87568b042042f7a9c0afed34/tomli-2.4.0-cp312-cp312-musllinux_1_2_aarch64.whl", hash = "sha256:8768715ffc41f0008abe25d808c20c3d990f42b6e2e58305d5da280ae7d1fa3b", size = 247014, upload-time = "2026-01-11T11:22:01.238Z" },
+ { url = "https://files.pythonhosted.org/packages/2c/0e/2e37568edd944b4165735687cbaf2fe3648129e440c26d02223672ee0630/tomli-2.4.0-cp312-cp312-musllinux_1_2_x86_64.whl", hash = "sha256:7b438885858efd5be02a9a133caf5812b8776ee0c969fea02c45e8e3f296ba51", size = 251820, upload-time = "2026-01-11T11:22:02.727Z" },
+ { url = "https://files.pythonhosted.org/packages/5a/1c/ee3b707fdac82aeeb92d1a113f803cf6d0f37bdca0849cb489553e1f417a/tomli-2.4.0-cp312-cp312-win32.whl", hash = "sha256:0408e3de5ec77cc7f81960c362543cbbd91ef883e3138e81b729fc3eea5b9729", size = 97712, upload-time = "2026-01-11T11:22:03.777Z" },
+ { url = "https://files.pythonhosted.org/packages/69/13/c07a9177d0b3bab7913299b9278845fc6eaaca14a02667c6be0b0a2270c8/tomli-2.4.0-cp312-cp312-win_amd64.whl", hash = "sha256:685306e2cc7da35be4ee914fd34ab801a6acacb061b6a7abca922aaf9ad368da", size = 108296, upload-time = "2026-01-11T11:22:04.86Z" },
+ { url = "https://files.pythonhosted.org/packages/18/27/e267a60bbeeee343bcc279bb9e8fbed0cbe224bc7b2a3dc2975f22809a09/tomli-2.4.0-cp312-cp312-win_arm64.whl", hash = "sha256:5aa48d7c2356055feef06a43611fc401a07337d5b006be13a30f6c58f869e3c3", size = 94553, upload-time = "2026-01-11T11:22:05.854Z" },
+ { url = "https://files.pythonhosted.org/packages/23/d1/136eb2cb77520a31e1f64cbae9d33ec6df0d78bdf4160398e86eec8a8754/tomli-2.4.0-py3-none-any.whl", hash = "sha256:1f776e7d669ebceb01dee46484485f43a4048746235e683bcdffacdf1fb4785a", size = 14477, upload-time = "2026-01-11T11:22:37.446Z" },
+]
+
[[package]]
name = "tomlkit"
version = "0.13.3"
diff --git a/testing/qdp/test_bindings.py b/testing/qdp/test_bindings.py
index d213d55cd..0e692a366 100644
--- a/testing/qdp/test_bindings.py
+++ b/testing/qdp/test_bindings.py
@@ -1099,3 +1099,196 @@ def test_iqp_encode_errors():
# Non-finite parameter (negative infinity)
with pytest.raises(RuntimeError, match="must be finite"):
engine.encode([float("-inf"), 0.0], 2, "iqp-z")
+
+
+# ==================== IQP FWT Optimization Tests ====================
+
+
[email protected]
+def test_iqp_fwt_normalization():
+ """Test that FWT-optimized IQP produces normalized states (requires
GPU)."""
+ pytest.importorskip("torch")
+ import torch
+ from _qdp import QdpEngine
+
+ if not torch.cuda.is_available():
+ pytest.skip("GPU required for QdpEngine")
+
+ engine = QdpEngine(0)
+
+ # Test across FWT threshold (FWT_MIN_QUBITS = 4)
+ for num_qubits in [3, 4, 5, 6, 7, 8]:
+ # Full IQP (n + n*(n-1)/2 parameters)
+ data_len = num_qubits + num_qubits * (num_qubits - 1) // 2
+ data = [0.1 * i for i in range(data_len)]
+
+ qtensor = engine.encode(data, num_qubits, "iqp")
+ torch_tensor = torch.from_dlpack(qtensor)
+
+ # Verify normalization (sum of |amplitude|^2 = 1)
+ norm = torch.sum(torch.abs(torch_tensor) ** 2)
+ assert torch.isclose(norm, torch.tensor(1.0, device="cuda:0"), atol=1e-6), (
+ f"IQP {num_qubits} qubits not normalized: got {norm.item()}"
+ )
+
+
[email protected]
+def test_iqp_z_fwt_normalization():
+ """Test that FWT-optimized IQP-Z produces normalized states (requires
GPU)."""
+ pytest.importorskip("torch")
+ import torch
+ from _qdp import QdpEngine
+
+ if not torch.cuda.is_available():
+ pytest.skip("GPU required for QdpEngine")
+
+ engine = QdpEngine(0)
+
+ # Test across FWT threshold
+ for num_qubits in [3, 4, 5, 6, 7, 8]:
+ data = [0.2 * i for i in range(num_qubits)]
+
+ qtensor = engine.encode(data, num_qubits, "iqp-z")
+ torch_tensor = torch.from_dlpack(qtensor)
+
+ norm = torch.sum(torch.abs(torch_tensor) ** 2)
+ assert torch.isclose(norm, torch.tensor(1.0, device="cuda:0"), atol=1e-6), (
+ f"IQP-Z {num_qubits} qubits not normalized: got {norm.item()}"
+ )
+
+
[email protected]
+def test_iqp_fwt_zero_params_gives_zero_state():
+ """Test that zero parameters produce |0...0⟩ state (requires GPU).
+
+ With zero parameters, the IQP circuit is H^n * I * H^n = I,
+ so |0⟩^n maps to |0⟩^n with amplitude 1 at index 0.
+ """
+ pytest.importorskip("torch")
+ import torch
+ from _qdp import QdpEngine
+
+ if not torch.cuda.is_available():
+ pytest.skip("GPU required for QdpEngine")
+
+ engine = QdpEngine(0)
+
+ # Test FWT-optimized path (n >= 4)
+ for num_qubits in [4, 5, 6]:
+ data_len = num_qubits + num_qubits * (num_qubits - 1) // 2
+ data = [0.0] * data_len
+
+ qtensor = engine.encode(data, num_qubits, "iqp")
+ torch_tensor = torch.from_dlpack(qtensor)
+
+ # Should get |0...0⟩: amplitude 1 at index 0, 0 elsewhere
+ state_len = 1 << num_qubits
+ expected = torch.zeros(
+ (1, state_len), dtype=torch_tensor.dtype, device="cuda:0"
+ )
+ expected[0, 0] = 1.0 + 0j
+
+ assert torch.allclose(torch_tensor, expected, atol=1e-6), (
+ f"IQP {num_qubits} qubits with zero params should give |0⟩ state"
+ )
+
+
[email protected]
+def test_iqp_fwt_batch_normalization():
+ """Test that FWT-optimized batch IQP produces normalized states (requires
GPU)."""
+ pytest.importorskip("torch")
+ import torch
+ from _qdp import QdpEngine
+
+ if not torch.cuda.is_available():
+ pytest.skip("GPU required for QdpEngine")
+
+ engine = QdpEngine(0)
+
+ # Test batch encoding across FWT threshold
+ for num_qubits in [4, 5, 6]:
+ data_len = num_qubits + num_qubits * (num_qubits - 1) // 2
+ batch_size = 8
+
+ data = torch.tensor(
+ [
+ [0.1 * (i + j * data_len) for i in range(data_len)]
+ for j in range(batch_size)
+ ],
+ dtype=torch.float64,
+ )
+
+ qtensor = engine.encode(data, num_qubits, "iqp")
+ torch_tensor = torch.from_dlpack(qtensor)
+
+ assert torch_tensor.shape == (batch_size, 1 << num_qubits)
+
+ # Check each sample is normalized
+ for i in range(batch_size):
+ norm = torch.sum(torch.abs(torch_tensor[i]) ** 2)
+ assert torch.isclose(norm, torch.tensor(1.0, device="cuda:0"), atol=1e-6), (
+ f"IQP batch sample {i} not normalized: got {norm.item()}"
+ )
+
+
[email protected]
+def test_iqp_fwt_deterministic():
+ """Test that FWT-optimized IQP is deterministic (requires GPU)."""
+ pytest.importorskip("torch")
+ import torch
+ from _qdp import QdpEngine
+
+ if not torch.cuda.is_available():
+ pytest.skip("GPU required for QdpEngine")
+
+ engine = QdpEngine(0)
+
+ num_qubits = 6 # Uses FWT path
+ data_len = num_qubits + num_qubits * (num_qubits - 1) // 2
+ data = [0.3 * i for i in range(data_len)]
+
+ # Run encoding twice
+ qtensor1 = engine.encode(data, num_qubits, "iqp")
+ tensor1 = torch.from_dlpack(qtensor1).clone()
+
+ qtensor2 = engine.encode(data, num_qubits, "iqp")
+ tensor2 = torch.from_dlpack(qtensor2)
+
+ # Results should be identical
+ assert torch.allclose(tensor1, tensor2, atol=1e-10), (
+ "IQP FWT encoding should be deterministic"
+ )
+
+
[email protected]
+def test_iqp_fwt_shared_vs_global_memory_threshold():
+ """Test IQP encoding at shared memory threshold boundary (requires GPU).
+
+ FWT_SHARED_MEM_THRESHOLD = 10, so:
+ - n <= 10: uses shared memory FWT
+ - n > 10: uses global memory FWT
+ """
+ pytest.importorskip("torch")
+ import torch
+ from _qdp import QdpEngine
+
+ if not torch.cuda.is_available():
+ pytest.skip("GPU required for QdpEngine")
+
+ engine = QdpEngine(0)
+
+ # Test at and around the shared memory threshold
+ # n <= 10: shared memory FWT, n > 10: global memory FWT (multi-launch)
+ for num_qubits in [9, 10, 11]:
+ data_len = num_qubits + num_qubits * (num_qubits - 1) // 2
+ data = [0.05 * i for i in range(data_len)]
+
+ qtensor = engine.encode(data, num_qubits, "iqp")
+ torch_tensor = torch.from_dlpack(qtensor)
+
+ assert torch_tensor.shape == (1, 1 << num_qubits)
+
+ norm = torch.sum(torch.abs(torch_tensor) ** 2)
+ assert torch.isclose(norm, torch.tensor(1.0, device="cuda:0"), atol=1e-6), (
+ f"IQP {num_qubits} qubits not normalized at threshold: got {norm.item()}"
+ )
diff --git a/uv.lock b/uv.lock
index f1124a64d..0f1a55e10 100644
--- a/uv.lock
+++ b/uv.lock
@@ -1997,6 +1997,10 @@ benchmark = [
{ name = "torch", specifier = ">=2.2,<=2.9.0" },
{ name = "tqdm" },
]
+dev = [
+ { name = "pytest" },
+ { name = "torch", specifier = ">=2.2,<=2.9.0" },
+]
[[package]]
name = "requests"