This is an automated email from the ASF dual-hosted git repository.

guanmingchiu pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/mahout.git

commit dab5577cd76a33381488a6ec5f8d9a9f58e3c068
Author: KUAN-HAO HUANG <[email protected]>
AuthorDate: Wed Dec 24 23:47:32 2025 +0800

    [QDP] DLPack shape/strides: Support batch 2D tensor (#747)
---
 qdp/benchmark/benchmark_e2e.py             |  23 ++++---
 qdp/qdp-core/src/dlpack.rs                 |  24 +++++--
 qdp/qdp-core/src/gpu/memory.rs             |   5 ++
 qdp/qdp-core/src/lib.rs                    |  30 ++++++--
 qdp/qdp-core/src/preprocessing.rs          |   6 ++
 qdp/qdp-core/tests/api_workflow.rs         | 106 +++++++++++++++++++++++++++--
 qdp/qdp-core/tests/memory_safety.rs        |  34 +++++----
 qdp/qdp-core/tests/validation.rs           |  24 +++++++
 qdp/qdp-python/src/lib.rs                  |   1 +
 qdp/qdp-python/tests/test_bindings.py      |   4 +-
 qdp/qdp-python/tests/test_high_fidelity.py |   5 +-
 11 files changed, 221 insertions(+), 41 deletions(-)

diff --git a/qdp/benchmark/benchmark_e2e.py b/qdp/benchmark/benchmark_e2e.py
index b72c81d1f..0d419d0bf 100644
--- a/qdp/benchmark/benchmark_e2e.py
+++ b/qdp/benchmark/benchmark_e2e.py
@@ -277,15 +277,17 @@ def run_mahout_parquet(engine, n_qubits, n_samples):
     dlpack_time = time.perf_counter() - dlpack_start
     print(f"  DLPack conversion: {dlpack_time:.4f} s")
 
-    # Reshape to [n_samples, state_len] (still complex)
+    # Tensor is already 2D [n_samples, state_len] from to_dlpack()
     state_len = 1 << n_qubits
+    assert gpu_batched.shape == (n_samples, state_len), (
+        f"Expected shape ({n_samples}, {state_len}), got {gpu_batched.shape}"
+    )
 
     # Convert to float for model (batch already on GPU)
     reshape_start = time.perf_counter()
-    gpu_reshaped = gpu_batched.view(n_samples, state_len)
-    gpu_all_data = gpu_reshaped.abs().to(torch.float32)
+    gpu_all_data = gpu_batched.abs().to(torch.float32)
     reshape_time = time.perf_counter() - reshape_start
-    print(f"  Reshape & convert: {reshape_time:.4f} s")
+    print(f"  Convert to float32: {reshape_time:.4f} s")
 
     # Forward pass (data already on GPU)
     for i in range(0, n_samples, BATCH_SIZE):
@@ -299,7 +301,7 @@ def run_mahout_parquet(engine, n_qubits, n_samples):
     # Clean cache after benchmark completion
     clean_cache()
 
-    return total_time, gpu_reshaped
+    return total_time, gpu_batched
 
 
 # -----------------------------------------------------------
@@ -325,13 +327,16 @@ def run_mahout_arrow(engine, n_qubits, n_samples):
     dlpack_time = time.perf_counter() - dlpack_start
     print(f"  DLPack conversion: {dlpack_time:.4f} s")
 
+    # Tensor is already 2D [n_samples, state_len] from to_dlpack()
     state_len = 1 << n_qubits
+    assert gpu_batched.shape == (n_samples, state_len), (
+        f"Expected shape ({n_samples}, {state_len}), got {gpu_batched.shape}"
+    )
 
     reshape_start = time.perf_counter()
-    gpu_reshaped = gpu_batched.view(n_samples, state_len)
-    gpu_all_data = gpu_reshaped.abs().to(torch.float32)
+    gpu_all_data = gpu_batched.abs().to(torch.float32)
     reshape_time = time.perf_counter() - reshape_start
-    print(f"  Reshape & convert: {reshape_time:.4f} s")
+    print(f"  Convert to float32: {reshape_time:.4f} s")
 
     for i in range(0, n_samples, BATCH_SIZE):
         batch = gpu_all_data[i : i + BATCH_SIZE]
@@ -344,7 +349,7 @@ def run_mahout_arrow(engine, n_qubits, n_samples):
     # Clean cache after benchmark completion
     clean_cache()
 
-    return total_time, gpu_reshaped
+    return total_time, gpu_batched
 
 
 def compare_states(name_a, states_a, name_b, states_b):
diff --git a/qdp/qdp-core/src/dlpack.rs b/qdp/qdp-core/src/dlpack.rs
index dd134ca5d..e84630ca6 100644
--- a/qdp/qdp-core/src/dlpack.rs
+++ b/qdp/qdp-core/src/dlpack.rs
@@ -120,9 +120,25 @@ impl GpuStateVector {
     /// Freed by DLPack deleter when PyTorch releases tensor.
     /// Do not free manually.
     pub fn to_dlpack(&self) -> *mut DLManagedTensor {
-        // Allocate shape/strides on heap (freed by deleter)
-        let shape = vec![self.size_elements as i64];
-        let strides = vec![1i64];
+        // Always return 2D tensor: Batch [num_samples, state_len], Single [1, state_len]
+        let (shape, strides) = if let Some(num_samples) = self.num_samples {
+            // Batch: [num_samples, state_len_per_sample]
+            debug_assert!(
+                num_samples > 0 && self.size_elements % num_samples == 0,
+                "Batch state vector size must be divisible by num_samples"
+            );
+            let state_len_per_sample = self.size_elements / num_samples;
+            let shape = vec![num_samples as i64, state_len_per_sample as i64];
+            let strides = vec![state_len_per_sample as i64, 1i64];
+            (shape, strides)
+        } else {
+            // Single: [1, size_elements]
+            let state_len = self.size_elements;
+            let shape = vec![1i64, state_len as i64];
+            let strides = vec![state_len as i64, 1i64];
+            (shape, strides)
+        };
+        let ndim: c_int = 2;
 
         // Transfer ownership to DLPack deleter
         let shape_ptr = Box::into_raw(shape.into_boxed_slice()) as *mut i64;
@@ -142,7 +158,7 @@ impl GpuStateVector {
                 device_type: DLDeviceType::kDLCUDA,
                 device_id: self.device_id as c_int,
             },
-            ndim: 1,
+            ndim,
             dtype: DLDataType {
                 code: DL_COMPLEX,
                 bits: dtype_bits,
diff --git a/qdp/qdp-core/src/gpu/memory.rs b/qdp/qdp-core/src/gpu/memory.rs
index 1cfd32eca..97e3d9cbf 100644
--- a/qdp/qdp-core/src/gpu/memory.rs
+++ b/qdp/qdp-core/src/gpu/memory.rs
@@ -190,6 +190,8 @@ pub struct GpuStateVector {
     pub(crate) buffer: Arc<BufferStorage>,
     pub num_qubits: usize,
     pub size_elements: usize,
+    /// Batch size (None for single state)
+    pub(crate) num_samples: Option<usize>,
     pub device_id: usize,
 }
 
@@ -230,6 +232,7 @@ impl GpuStateVector {
                 buffer: Arc::new(BufferStorage::F64(GpuBufferRaw { slice })),
                 num_qubits: qubits,
                 size_elements: _size_elements,
+                num_samples: None,
                 device_id: _device.ordinal(),
             })
         }
@@ -302,6 +305,7 @@ impl GpuStateVector {
                 buffer: Arc::new(BufferStorage::F64(GpuBufferRaw { slice })),
                 num_qubits: qubits,
                 size_elements: total_elements,
+                num_samples: Some(num_samples),
                 device_id: _device.ordinal(),
             })
         }
@@ -367,6 +371,7 @@ impl GpuStateVector {
                         buffer: Arc::new(BufferStorage::F32(GpuBufferRaw { slice })),
                         num_qubits: self.num_qubits,
                         size_elements: self.size_elements,
+                        num_samples: self.num_samples, // Preserve batch information
                         device_id: device.ordinal(),
                     })
                 }
diff --git a/qdp/qdp-core/src/lib.rs b/qdp/qdp-core/src/lib.rs
index 429813c26..1a1d1b320 100644
--- a/qdp/qdp-core/src/lib.rs
+++ b/qdp/qdp-core/src/lib.rs
@@ -87,7 +87,7 @@ impl QdpEngine {
     /// * `encoding_method` - Strategy: "amplitude", "angle", or "basis"
     ///
     /// # Returns
-    /// DLPack pointer for zero-copy PyTorch integration
+    /// DLPack pointer for zero-copy PyTorch integration (shape: [1, 2^num_qubits])
     ///
     /// # Safety
     /// Pointer freed by DLPack deleter, do not free manually.
@@ -201,6 +201,27 @@ impl QdpEngine {
             if sample_size == 0 {
                 return Err(MahoutError::InvalidInput("Sample size cannot be 
zero".into()));
             }
+            if sample_size > STAGE_SIZE_ELEMENTS {
+                return Err(MahoutError::InvalidInput(format!(
+                    "Sample size {} exceeds staging buffer capacity {} 
(elements)",
+                    sample_size, STAGE_SIZE_ELEMENTS
+                )));
+            }
+
+            // Reuse a single norm buffer across chunks to avoid per-chunk allocations.
+            //
+            // Important: the norm buffer must outlive the async kernels that consume it.
+            // Per-chunk allocation + drop can lead to use-after-free when the next chunk
+            // reuses the same device memory while the previous chunk is still running.
+            let max_samples_per_chunk = std::cmp::max(
+                1,
+                std::cmp::min(num_samples, STAGE_SIZE_ELEMENTS / sample_size),
+            );
+            let mut norm_buffer = self.device.alloc_zeros::<f64>(max_samples_per_chunk)
+                .map_err(|e| MahoutError::MemoryAllocation(format!(
+                    "Failed to allocate norm buffer: {:?}",
+                    e
+                )))?;
 
             full_buf_tx.send(Ok((host_buf_first, first_len)))
                 .map_err(|_| MahoutError::Io("Failed to send first 
buffer".into()))?;
@@ -277,9 +298,10 @@ impl QdpEngine {
                             let state_ptr_offset = total_state_vector.ptr_void().cast::<u8>()
                                 .add(offset_bytes)
                                 .cast::<std::ffi::c_void>();
-
-                            let mut norm_buffer = self.device.alloc_zeros::<f64>(samples_in_chunk)
-                                .map_err(|e| MahoutError::MemoryAllocation(format!("Failed to allocate norm buffer: {:?}", e)))?;
+                            debug_assert!(
+                                samples_in_chunk <= max_samples_per_chunk,
+                                "samples_in_chunk must be <= 
max_samples_per_chunk"
+                            );
 
                             {
                                 crate::profile_scope!("GPU::NormBatch");
diff --git a/qdp/qdp-core/src/preprocessing.rs b/qdp/qdp-core/src/preprocessing.rs
index 0d8e70148..43577a8eb 100644
--- a/qdp/qdp-core/src/preprocessing.rs
+++ b/qdp/qdp-core/src/preprocessing.rs
@@ -84,6 +84,12 @@ impl Preprocessor {
         sample_size: usize,
         num_qubits: usize,
     ) -> Result<()> {
+        if num_samples == 0 {
+            return Err(MahoutError::InvalidInput(
+                "num_samples must be greater than 0".to_string()
+            ));
+        }
+
         if batch_data.len() != num_samples * sample_size {
             return Err(MahoutError::InvalidInput(
                 format!("Batch data length {} doesn't match num_samples {} * 
sample_size {}",
diff --git a/qdp/qdp-core/tests/api_workflow.rs b/qdp/qdp-core/tests/api_workflow.rs
index 13c2126ec..0973d27d5 100644
--- a/qdp/qdp-core/tests/api_workflow.rs
+++ b/qdp/qdp-core/tests/api_workflow.rs
@@ -55,9 +55,7 @@ fn test_amplitude_encoding_workflow() {
     println!("Created test data: {} elements", data.len());
 
     let result = engine.encode(&data, 10, "amplitude");
-    assert!(result.is_ok(), "Encoding should succeed");
-
-    let dlpack_ptr = result.unwrap();
+    let dlpack_ptr = result.expect("Encoding should succeed");
     assert!(!dlpack_ptr.is_null(), "DLPack pointer should not be null");
     println!("PASS: Encoding succeeded, DLPack pointer valid");
 
@@ -91,9 +89,7 @@ fn test_amplitude_encoding_async_pipeline() {
     println!("Created test data: {} elements", data.len());
 
     let result = engine.encode(&data, 18, "amplitude");
-    assert!(result.is_ok(), "Encoding should succeed");
-
-    let dlpack_ptr = result.unwrap();
+    let dlpack_ptr = result.expect("Encoding should succeed");
     assert!(!dlpack_ptr.is_null(), "DLPack pointer should not be null");
     println!("PASS: Encoding succeeded, DLPack pointer valid");
 
@@ -108,6 +104,104 @@ fn test_amplitude_encoding_async_pipeline() {
     }
 }
 
+#[test]
+#[cfg(target_os = "linux")]
+fn test_batch_dlpack_2d_shape() {
+    println!("Testing batch DLPack 2D shape...");
+
+    let engine = match QdpEngine::new(0) {
+        Ok(e) => e,
+        Err(_) => {
+            println!("SKIP: No GPU available");
+            return;
+        }
+    };
+
+    // Create batch data: 3 samples, each with 4 elements (2 qubits)
+    let num_samples = 3;
+    let num_qubits = 2;
+    let sample_size = 4;
+    let batch_data: Vec<f64> = (0..num_samples * sample_size)
+        .map(|i| (i as f64) / 10.0)
+        .collect();
+
+    let result = engine.encode_batch(&batch_data, num_samples, sample_size, num_qubits, "amplitude");
+    let dlpack_ptr = result.expect("Batch encoding should succeed");
+    assert!(!dlpack_ptr.is_null(), "DLPack pointer should not be null");
+
+    unsafe {
+        let managed = &*dlpack_ptr;
+        let tensor = &managed.dl_tensor;
+
+        // Verify 2D shape for batch tensor
+        assert_eq!(tensor.ndim, 2, "Batch tensor should be 2D");
+
+        let shape_slice = std::slice::from_raw_parts(tensor.shape, tensor.ndim as usize);
+        assert_eq!(shape_slice[0], num_samples as i64, "First dimension should be num_samples");
+        assert_eq!(shape_slice[1], (1 << num_qubits) as i64, "Second dimension should be 2^num_qubits");
+
+        let strides_slice = std::slice::from_raw_parts(tensor.strides, tensor.ndim as usize);
+        let state_len = 1 << num_qubits;
+        assert_eq!(strides_slice[0], state_len as i64, "Stride for first dimension should be state_len");
+        assert_eq!(strides_slice[1], 1, "Stride for second dimension should be 1");
+
+        println!("PASS: Batch DLPack tensor has correct 2D shape: [{}, {}]", 
shape_slice[0], shape_slice[1]);
+        println!("PASS: Strides are correct: [{}, {}]", strides_slice[0], 
strides_slice[1]);
+
+        // Free memory
+        if let Some(deleter) = managed.deleter {
+            deleter(dlpack_ptr);
+        }
+    }
+}
+
+#[test]
+#[cfg(target_os = "linux")]
+fn test_single_encode_dlpack_2d_shape() {
+    println!("Testing single encode returns 2D shape...");
+
+    let engine = match QdpEngine::new(0) {
+        Ok(e) => e,
+        Err(_) => {
+            println!("SKIP: No GPU available");
+            return;
+        }
+    };
+
+    let data = common::create_test_data(16);
+    let result = engine.encode(&data, 4, "amplitude");
+    assert!(result.is_ok(), "Encoding should succeed");
+
+    let dlpack_ptr = result.unwrap();
+    assert!(!dlpack_ptr.is_null(), "DLPack pointer should not be null");
+
+    unsafe {
+        let managed = &*dlpack_ptr;
+        let tensor = &managed.dl_tensor;
+
+        // Verify 2D shape for single encode: [1, 2^num_qubits]
+        assert_eq!(tensor.ndim, 2, "Single encode should be 2D");
+
+        let shape_slice = std::slice::from_raw_parts(tensor.shape, tensor.ndim as usize);
+        assert_eq!(shape_slice[0], 1, "First dimension should be 1 for single encode");
+        assert_eq!(shape_slice[1], 16, "Second dimension should be [2^4]");
+
+        let strides_slice = std::slice::from_raw_parts(tensor.strides, tensor.ndim as usize);
+        assert_eq!(strides_slice[0], 16, "Stride for first dimension should be state_len");
+        assert_eq!(strides_slice[1], 1, "Stride for second dimension should be 1");
+
+        println!(
+            "PASS: Single encode returns 2D shape: [{}, {}]",
+            shape_slice[0], shape_slice[1]
+        );
+
+        // Free memory
+        if let Some(deleter) = managed.deleter {
+            deleter(dlpack_ptr);
+        }
+    }
+}
+
 #[test]
 #[cfg(target_os = "linux")]
 fn test_dlpack_device_id() {
diff --git a/qdp/qdp-core/tests/memory_safety.rs b/qdp/qdp-core/tests/memory_safety.rs
index 6aa2d355a..7084c071f 100644
--- a/qdp/qdp-core/tests/memory_safety.rs
+++ b/qdp/qdp-core/tests/memory_safety.rs
@@ -106,24 +106,26 @@ fn test_dlpack_tensor_metadata_default() {
         let managed = &mut *ptr;
         let tensor = &managed.dl_tensor;
 
-        assert_eq!(tensor.ndim, 1, "Should be 1D tensor");
+        assert_eq!(tensor.ndim, 2, "Should be 2D tensor");
         assert!(!tensor.data.is_null(), "Data pointer should be valid");
         assert!(!tensor.shape.is_null(), "Shape pointer should be valid");
         assert!(!tensor.strides.is_null(), "Strides pointer should be valid");
 
-        let shape = *tensor.shape;
-        assert_eq!(shape, 1024, "Shape should be 1024 (2^10)");
+        let shape = std::slice::from_raw_parts(tensor.shape, tensor.ndim as usize);
+        assert_eq!(shape[0], 1, "First dimension should be 1 for single encode");
+        assert_eq!(shape[1], 1024, "Second dimension should be 1024 (2^10)");
 
-        let stride = *tensor.strides;
-        assert_eq!(stride, 1, "Stride for 1D contiguous array should be 1");
+        let strides = std::slice::from_raw_parts(tensor.strides, tensor.ndim as usize);
+        assert_eq!(strides[0], 1024, "Stride for first dimension should be state_len");
+        assert_eq!(strides[1], 1, "Stride for second dimension should be 1");
 
         assert_eq!(tensor.dtype.code, 5, "Should be complex type (code=5)");
-        assert_eq!(tensor.dtype.bits, 64, "Should be 64 bits (2x32-bit floats)");
+        assert_eq!(tensor.dtype.bits, 128, "Should be 128 bits (2x64-bit floats, Float64)");
 
         println!("PASS: DLPack metadata verified");
         println!("  ndim: {}", tensor.ndim);
-        println!("  shape: {}", shape);
-        println!("  stride: {}", stride);
+        println!("  shape: [{}, {}]", shape[0], shape[1]);
+        println!("  strides: [{}, {}]", strides[0], strides[1]);
         println!(
             "  dtype: code={}, bits={}",
             tensor.dtype.code, tensor.dtype.bits
@@ -154,16 +156,18 @@ fn test_dlpack_tensor_metadata_f64() {
         let managed = &mut *ptr;
         let tensor = &managed.dl_tensor;
 
-        assert_eq!(tensor.ndim, 1, "Should be 1D tensor");
+        assert_eq!(tensor.ndim, 2, "Should be 2D tensor");
         assert!(!tensor.data.is_null(), "Data pointer should be valid");
         assert!(!tensor.shape.is_null(), "Shape pointer should be valid");
         assert!(!tensor.strides.is_null(), "Strides pointer should be valid");
 
-        let shape = *tensor.shape;
-        assert_eq!(shape, 1024, "Shape should be 1024 (2^10)");
+        let shape = std::slice::from_raw_parts(tensor.shape, tensor.ndim as usize);
+        assert_eq!(shape[0], 1, "First dimension should be 1 for single encode");
+        assert_eq!(shape[1], 1024, "Second dimension should be 1024 (2^10)");
 
-        let stride = *tensor.strides;
-        assert_eq!(stride, 1, "Stride for 1D contiguous array should be 1");
+        let strides = std::slice::from_raw_parts(tensor.strides, tensor.ndim as usize);
+        assert_eq!(strides[0], 1024, "Stride for first dimension should be state_len");
+        assert_eq!(strides[1], 1, "Stride for second dimension should be 1");
 
         assert_eq!(tensor.dtype.code, 5, "Should be complex type (code=5)");
         assert_eq!(
@@ -173,8 +177,8 @@ fn test_dlpack_tensor_metadata_f64() {
 
         println!("PASS: DLPack metadata verified");
         println!("  ndim: {}", tensor.ndim);
-        println!("  shape: {}", shape);
-        println!("  stride: {}", stride);
+        println!("  shape: [{}, {}]", shape[0], shape[1]);
+        println!("  strides: [{}, {}]", strides[0], strides[1]);
         println!(
             "  dtype: code={}, bits={}",
             tensor.dtype.code, tensor.dtype.bits
diff --git a/qdp/qdp-core/tests/validation.rs b/qdp/qdp-core/tests/validation.rs
index cc12a995a..6fc591e53 100644
--- a/qdp/qdp-core/tests/validation.rs
+++ b/qdp/qdp-core/tests/validation.rs
@@ -119,6 +119,30 @@ fn test_input_validation_max_qubits() {
     }
 }
 
+#[test]
+#[cfg(target_os = "linux")]
+fn test_input_validation_batch_zero_samples() {
+    println!("Testing zero num_samples rejection...");
+
+    let engine = match QdpEngine::new(0) {
+        Ok(e) => e,
+        Err(_) => return,
+    };
+
+    let batch_data = vec![1.0, 2.0, 3.0, 4.0];
+    let result = engine.encode_batch(&batch_data, 0, 4, 2, "amplitude");
+    assert!(result.is_err(), "Should reject zero num_samples");
+
+    match result {
+        Err(MahoutError::InvalidInput(msg)) => {
+            assert!(msg.contains("num_samples must be greater than 0"),
+                    "Error should mention num_samples requirement");
+            println!("PASS: Correctly rejected zero num_samples: {}", msg);
+        }
+        _ => panic!("Expected InvalidInput error for zero num_samples"),
+    }
+}
+
 #[test]
 #[cfg(target_os = "linux")]
 fn test_empty_data() {
diff --git a/qdp/qdp-python/src/lib.rs b/qdp/qdp-python/src/lib.rs
index d94aceeb2..642a9a7fa 100644
--- a/qdp/qdp-python/src/lib.rs
+++ b/qdp/qdp-python/src/lib.rs
@@ -188,6 +188,7 @@ impl QdpEngine {
     ///
     /// Returns:
     ///     QuantumTensor: DLPack-compatible tensor for zero-copy PyTorch integration
+    ///         Shape: [1, 2^num_qubits]
     ///
     /// Raises:
     ///     RuntimeError: If encoding fails
diff --git a/qdp/qdp-python/tests/test_bindings.py b/qdp/qdp-python/tests/test_bindings.py
index d3cda3e22..7808abc8c 100644
--- a/qdp/qdp-python/tests/test_bindings.py
+++ b/qdp/qdp-python/tests/test_bindings.py
@@ -126,8 +126,8 @@ def test_pytorch_integration():
     assert torch_tensor.device.index == 0
     assert torch_tensor.dtype == torch.complex64
 
-    # Verify shape (2 qubits = 2^2 = 4 elements)
-    assert torch_tensor.shape == (4,)
+    # Verify shape (2 qubits = 2^2 = 4 elements) as 2D for consistency: [1, 4]
+    assert torch_tensor.shape == (1, 4)
 
 
 @pytest.mark.gpu
diff --git a/qdp/qdp-python/tests/test_high_fidelity.py b/qdp/qdp-python/tests/test_high_fidelity.py
index 24f11c513..9046272cb 100644
--- a/qdp/qdp-python/tests/test_high_fidelity.py
+++ b/qdp/qdp-python/tests/test_high_fidelity.py
@@ -36,6 +36,9 @@ def calculate_fidelity(
 ) -> float:
     """Calculate quantum state fidelity: F = |<ψ_gpu | ψ_cpu>|²"""
     psi_gpu = state_vector_gpu.cpu().numpy()
+    # Convert 2D [1, state_len] to 1D for compatibility with ground truth
+    if psi_gpu.ndim == 2 and psi_gpu.shape[0] == 1:
+        psi_gpu = psi_gpu[0]
 
     if np.any(np.isnan(psi_gpu)) or np.any(np.isinf(psi_gpu)):
         return 0.0
@@ -103,7 +106,7 @@ def test_amplitude_encoding_fidelity_comprehensive(
 
     assert torch_state.is_cuda, "Tensor must be on GPU"
     assert torch_state.dtype == torch.complex128, "Tensor must be Complex128"
-    assert torch_state.shape[0] == state_len, "Tensor shape must match 2^n"
+    assert torch_state.shape == (1, state_len), "Tensor shape must be [1, 2^n]"
 
     fidelity = calculate_fidelity(torch_state, expected_state_complex)
     print(f"Fidelity: {fidelity:.16f}")
