This is an automated email from the ASF dual-hosted git repository.
richhuang pushed a commit to branch dev-qdp
in repository https://gitbox.apache.org/repos/asf/mahout.git
The following commit(s) were added to refs/heads/dev-qdp by this push:
new 03f71ec41 [QDP] DLPack shape/strides: Support batch 2D tensor (#747)
03f71ec41 is described below
commit 03f71ec41a4f63d0f65b3ea4d636460e37453de7
Author: KUAN-HAO HUANG <[email protected]>
AuthorDate: Wed Dec 24 23:47:32 2025 +0800
[QDP] DLPack shape/strides: Support batch 2D tensor (#747)
---
qdp/benchmark/benchmark_e2e.py | 23 ++++---
qdp/qdp-core/src/dlpack.rs | 24 +++++--
qdp/qdp-core/src/gpu/memory.rs | 5 ++
qdp/qdp-core/src/lib.rs | 30 ++++++--
qdp/qdp-core/src/preprocessing.rs | 6 ++
qdp/qdp-core/tests/api_workflow.rs | 106 +++++++++++++++++++++++++++--
qdp/qdp-core/tests/memory_safety.rs | 34 +++++----
qdp/qdp-core/tests/validation.rs | 24 +++++++
qdp/qdp-python/src/lib.rs | 1 +
qdp/qdp-python/tests/test_bindings.py | 4 +-
qdp/qdp-python/tests/test_high_fidelity.py | 5 +-
11 files changed, 221 insertions(+), 41 deletions(-)
diff --git a/qdp/benchmark/benchmark_e2e.py b/qdp/benchmark/benchmark_e2e.py
index b72c81d1f..0d419d0bf 100644
--- a/qdp/benchmark/benchmark_e2e.py
+++ b/qdp/benchmark/benchmark_e2e.py
@@ -277,15 +277,17 @@ def run_mahout_parquet(engine, n_qubits, n_samples):
dlpack_time = time.perf_counter() - dlpack_start
print(f" DLPack conversion: {dlpack_time:.4f} s")
- # Reshape to [n_samples, state_len] (still complex)
+ # Tensor is already 2D [n_samples, state_len] from to_dlpack()
state_len = 1 << n_qubits
+ assert gpu_batched.shape == (n_samples, state_len), (
+ f"Expected shape ({n_samples}, {state_len}), got {gpu_batched.shape}"
+ )
# Convert to float for model (batch already on GPU)
reshape_start = time.perf_counter()
- gpu_reshaped = gpu_batched.view(n_samples, state_len)
- gpu_all_data = gpu_reshaped.abs().to(torch.float32)
+ gpu_all_data = gpu_batched.abs().to(torch.float32)
reshape_time = time.perf_counter() - reshape_start
- print(f" Reshape & convert: {reshape_time:.4f} s")
+ print(f" Convert to float32: {reshape_time:.4f} s")
# Forward pass (data already on GPU)
for i in range(0, n_samples, BATCH_SIZE):
@@ -299,7 +301,7 @@ def run_mahout_parquet(engine, n_qubits, n_samples):
# Clean cache after benchmark completion
clean_cache()
- return total_time, gpu_reshaped
+ return total_time, gpu_batched
# -----------------------------------------------------------
@@ -325,13 +327,16 @@ def run_mahout_arrow(engine, n_qubits, n_samples):
dlpack_time = time.perf_counter() - dlpack_start
print(f" DLPack conversion: {dlpack_time:.4f} s")
+ # Tensor is already 2D [n_samples, state_len] from to_dlpack()
state_len = 1 << n_qubits
+ assert gpu_batched.shape == (n_samples, state_len), (
+ f"Expected shape ({n_samples}, {state_len}), got {gpu_batched.shape}"
+ )
reshape_start = time.perf_counter()
- gpu_reshaped = gpu_batched.view(n_samples, state_len)
- gpu_all_data = gpu_reshaped.abs().to(torch.float32)
+ gpu_all_data = gpu_batched.abs().to(torch.float32)
reshape_time = time.perf_counter() - reshape_start
- print(f" Reshape & convert: {reshape_time:.4f} s")
+ print(f" Convert to float32: {reshape_time:.4f} s")
for i in range(0, n_samples, BATCH_SIZE):
batch = gpu_all_data[i : i + BATCH_SIZE]
@@ -344,7 +349,7 @@ def run_mahout_arrow(engine, n_qubits, n_samples):
# Clean cache after benchmark completion
clean_cache()
- return total_time, gpu_reshaped
+ return total_time, gpu_batched
def compare_states(name_a, states_a, name_b, states_b):
diff --git a/qdp/qdp-core/src/dlpack.rs b/qdp/qdp-core/src/dlpack.rs
index dd134ca5d..e84630ca6 100644
--- a/qdp/qdp-core/src/dlpack.rs
+++ b/qdp/qdp-core/src/dlpack.rs
@@ -120,9 +120,25 @@ impl GpuStateVector {
/// Freed by DLPack deleter when PyTorch releases tensor.
/// Do not free manually.
pub fn to_dlpack(&self) -> *mut DLManagedTensor {
- // Allocate shape/strides on heap (freed by deleter)
- let shape = vec![self.size_elements as i64];
- let strides = vec![1i64];
+ // Always return 2D tensor: Batch [num_samples, state_len], Single [1,
state_len]
+ let (shape, strides) = if let Some(num_samples) = self.num_samples {
+ // Batch: [num_samples, state_len_per_sample]
+ debug_assert!(
+ num_samples > 0 && self.size_elements % num_samples == 0,
+ "Batch state vector size must be divisible by num_samples"
+ );
+ let state_len_per_sample = self.size_elements / num_samples;
+ let shape = vec![num_samples as i64, state_len_per_sample as i64];
+ let strides = vec![state_len_per_sample as i64, 1i64];
+ (shape, strides)
+ } else {
+ // Single: [1, size_elements]
+ let state_len = self.size_elements;
+ let shape = vec![1i64, state_len as i64];
+ let strides = vec![state_len as i64, 1i64];
+ (shape, strides)
+ };
+ let ndim: c_int = 2;
// Transfer ownership to DLPack deleter
let shape_ptr = Box::into_raw(shape.into_boxed_slice()) as *mut i64;
@@ -142,7 +158,7 @@ impl GpuStateVector {
device_type: DLDeviceType::kDLCUDA,
device_id: self.device_id as c_int,
},
- ndim: 1,
+ ndim,
dtype: DLDataType {
code: DL_COMPLEX,
bits: dtype_bits,
diff --git a/qdp/qdp-core/src/gpu/memory.rs b/qdp/qdp-core/src/gpu/memory.rs
index 1cfd32eca..97e3d9cbf 100644
--- a/qdp/qdp-core/src/gpu/memory.rs
+++ b/qdp/qdp-core/src/gpu/memory.rs
@@ -190,6 +190,8 @@ pub struct GpuStateVector {
pub(crate) buffer: Arc<BufferStorage>,
pub num_qubits: usize,
pub size_elements: usize,
+ /// Batch size (None for single state)
+ pub(crate) num_samples: Option<usize>,
pub device_id: usize,
}
@@ -230,6 +232,7 @@ impl GpuStateVector {
buffer: Arc::new(BufferStorage::F64(GpuBufferRaw { slice })),
num_qubits: qubits,
size_elements: _size_elements,
+ num_samples: None,
device_id: _device.ordinal(),
})
}
@@ -302,6 +305,7 @@ impl GpuStateVector {
buffer: Arc::new(BufferStorage::F64(GpuBufferRaw { slice })),
num_qubits: qubits,
size_elements: total_elements,
+ num_samples: Some(num_samples),
device_id: _device.ordinal(),
})
}
@@ -367,6 +371,7 @@ impl GpuStateVector {
buffer: Arc::new(BufferStorage::F32(GpuBufferRaw {
slice })),
num_qubits: self.num_qubits,
size_elements: self.size_elements,
+ num_samples: self.num_samples, // Preserve batch
information
device_id: device.ordinal(),
})
}
diff --git a/qdp/qdp-core/src/lib.rs b/qdp/qdp-core/src/lib.rs
index 429813c26..1a1d1b320 100644
--- a/qdp/qdp-core/src/lib.rs
+++ b/qdp/qdp-core/src/lib.rs
@@ -87,7 +87,7 @@ impl QdpEngine {
/// * `encoding_method` - Strategy: "amplitude", "angle", or "basis"
///
/// # Returns
- /// DLPack pointer for zero-copy PyTorch integration
+ /// DLPack pointer for zero-copy PyTorch integration (shape: [1,
2^num_qubits])
///
/// # Safety
/// Pointer freed by DLPack deleter, do not free manually.
@@ -201,6 +201,27 @@ impl QdpEngine {
if sample_size == 0 {
return Err(MahoutError::InvalidInput("Sample size cannot be
zero".into()));
}
+ if sample_size > STAGE_SIZE_ELEMENTS {
+ return Err(MahoutError::InvalidInput(format!(
+ "Sample size {} exceeds staging buffer capacity {}
(elements)",
+ sample_size, STAGE_SIZE_ELEMENTS
+ )));
+ }
+
+ // Reuse a single norm buffer across chunks to avoid per-chunk
allocations.
+ //
+ // Important: the norm buffer must outlive the async kernels that
consume it.
+ // Per-chunk allocation + drop can lead to use-after-free when the
next chunk
+ // reuses the same device memory while the previous chunk is still
running.
+ let max_samples_per_chunk = std::cmp::max(
+ 1,
+ std::cmp::min(num_samples, STAGE_SIZE_ELEMENTS / sample_size),
+ );
+ let mut norm_buffer =
self.device.alloc_zeros::<f64>(max_samples_per_chunk)
+ .map_err(|e| MahoutError::MemoryAllocation(format!(
+ "Failed to allocate norm buffer: {:?}",
+ e
+ )))?;
full_buf_tx.send(Ok((host_buf_first, first_len)))
.map_err(|_| MahoutError::Io("Failed to send first
buffer".into()))?;
@@ -277,9 +298,10 @@ impl QdpEngine {
let state_ptr_offset =
total_state_vector.ptr_void().cast::<u8>()
.add(offset_bytes)
.cast::<std::ffi::c_void>();
-
- let mut norm_buffer =
self.device.alloc_zeros::<f64>(samples_in_chunk)
- .map_err(|e|
MahoutError::MemoryAllocation(format!("Failed to allocate norm buffer: {:?}",
e)))?;
+ debug_assert!(
+ samples_in_chunk <= max_samples_per_chunk,
+ "samples_in_chunk must be <=
max_samples_per_chunk"
+ );
{
crate::profile_scope!("GPU::NormBatch");
diff --git a/qdp/qdp-core/src/preprocessing.rs
b/qdp/qdp-core/src/preprocessing.rs
index 0d8e70148..43577a8eb 100644
--- a/qdp/qdp-core/src/preprocessing.rs
+++ b/qdp/qdp-core/src/preprocessing.rs
@@ -84,6 +84,12 @@ impl Preprocessor {
sample_size: usize,
num_qubits: usize,
) -> Result<()> {
+ if num_samples == 0 {
+ return Err(MahoutError::InvalidInput(
+ "num_samples must be greater than 0".to_string()
+ ));
+ }
+
if batch_data.len() != num_samples * sample_size {
return Err(MahoutError::InvalidInput(
format!("Batch data length {} doesn't match num_samples {} *
sample_size {}",
diff --git a/qdp/qdp-core/tests/api_workflow.rs
b/qdp/qdp-core/tests/api_workflow.rs
index 13c2126ec..0973d27d5 100644
--- a/qdp/qdp-core/tests/api_workflow.rs
+++ b/qdp/qdp-core/tests/api_workflow.rs
@@ -55,9 +55,7 @@ fn test_amplitude_encoding_workflow() {
println!("Created test data: {} elements", data.len());
let result = engine.encode(&data, 10, "amplitude");
- assert!(result.is_ok(), "Encoding should succeed");
-
- let dlpack_ptr = result.unwrap();
+ let dlpack_ptr = result.expect("Encoding should succeed");
assert!(!dlpack_ptr.is_null(), "DLPack pointer should not be null");
println!("PASS: Encoding succeeded, DLPack pointer valid");
@@ -91,9 +89,7 @@ fn test_amplitude_encoding_async_pipeline() {
println!("Created test data: {} elements", data.len());
let result = engine.encode(&data, 18, "amplitude");
- assert!(result.is_ok(), "Encoding should succeed");
-
- let dlpack_ptr = result.unwrap();
+ let dlpack_ptr = result.expect("Encoding should succeed");
assert!(!dlpack_ptr.is_null(), "DLPack pointer should not be null");
println!("PASS: Encoding succeeded, DLPack pointer valid");
@@ -108,6 +104,104 @@ fn test_amplitude_encoding_async_pipeline() {
}
}
+#[test]
+#[cfg(target_os = "linux")]
+fn test_batch_dlpack_2d_shape() {
+ println!("Testing batch DLPack 2D shape...");
+
+ let engine = match QdpEngine::new(0) {
+ Ok(e) => e,
+ Err(_) => {
+ println!("SKIP: No GPU available");
+ return;
+ }
+ };
+
+ // Create batch data: 3 samples, each with 4 elements (2 qubits)
+ let num_samples = 3;
+ let num_qubits = 2;
+ let sample_size = 4;
+ let batch_data: Vec<f64> = (0..num_samples * sample_size)
+ .map(|i| (i as f64) / 10.0)
+ .collect();
+
+ let result = engine.encode_batch(&batch_data, num_samples, sample_size,
num_qubits, "amplitude");
+ let dlpack_ptr = result.expect("Batch encoding should succeed");
+ assert!(!dlpack_ptr.is_null(), "DLPack pointer should not be null");
+
+ unsafe {
+ let managed = &*dlpack_ptr;
+ let tensor = &managed.dl_tensor;
+
+ // Verify 2D shape for batch tensor
+ assert_eq!(tensor.ndim, 2, "Batch tensor should be 2D");
+
+ let shape_slice = std::slice::from_raw_parts(tensor.shape, tensor.ndim
as usize);
+ assert_eq!(shape_slice[0], num_samples as i64, "First dimension should
be num_samples");
+ assert_eq!(shape_slice[1], (1 << num_qubits) as i64, "Second dimension
should be 2^num_qubits");
+
+ let strides_slice = std::slice::from_raw_parts(tensor.strides,
tensor.ndim as usize);
+ let state_len = 1 << num_qubits;
+ assert_eq!(strides_slice[0], state_len as i64, "Stride for first
dimension should be state_len");
+ assert_eq!(strides_slice[1], 1, "Stride for second dimension should be
1");
+
+ println!("PASS: Batch DLPack tensor has correct 2D shape: [{}, {}]",
shape_slice[0], shape_slice[1]);
+ println!("PASS: Strides are correct: [{}, {}]", strides_slice[0],
strides_slice[1]);
+
+ // Free memory
+ if let Some(deleter) = managed.deleter {
+ deleter(dlpack_ptr);
+ }
+ }
+}
+
+#[test]
+#[cfg(target_os = "linux")]
+fn test_single_encode_dlpack_2d_shape() {
+ println!("Testing single encode returns 2D shape...");
+
+ let engine = match QdpEngine::new(0) {
+ Ok(e) => e,
+ Err(_) => {
+ println!("SKIP: No GPU available");
+ return;
+ }
+ };
+
+ let data = common::create_test_data(16);
+ let result = engine.encode(&data, 4, "amplitude");
+ assert!(result.is_ok(), "Encoding should succeed");
+
+ let dlpack_ptr = result.unwrap();
+ assert!(!dlpack_ptr.is_null(), "DLPack pointer should not be null");
+
+ unsafe {
+ let managed = &*dlpack_ptr;
+ let tensor = &managed.dl_tensor;
+
+ // Verify 2D shape for single encode: [1, 2^num_qubits]
+ assert_eq!(tensor.ndim, 2, "Single encode should be 2D");
+
+ let shape_slice = std::slice::from_raw_parts(tensor.shape, tensor.ndim
as usize);
+ assert_eq!(shape_slice[0], 1, "First dimension should be 1 for single
encode");
+        assert_eq!(shape_slice[1], 16, "Second dimension should be 16 (2^4)");
+
+ let strides_slice = std::slice::from_raw_parts(tensor.strides,
tensor.ndim as usize);
+ assert_eq!(strides_slice[0], 16, "Stride for first dimension should be
state_len");
+ assert_eq!(strides_slice[1], 1, "Stride for second dimension should be
1");
+
+ println!(
+ "PASS: Single encode returns 2D shape: [{}, {}]",
+ shape_slice[0], shape_slice[1]
+ );
+
+ // Free memory
+ if let Some(deleter) = managed.deleter {
+ deleter(dlpack_ptr);
+ }
+ }
+}
+
#[test]
#[cfg(target_os = "linux")]
fn test_dlpack_device_id() {
diff --git a/qdp/qdp-core/tests/memory_safety.rs
b/qdp/qdp-core/tests/memory_safety.rs
index 6aa2d355a..7084c071f 100644
--- a/qdp/qdp-core/tests/memory_safety.rs
+++ b/qdp/qdp-core/tests/memory_safety.rs
@@ -106,24 +106,26 @@ fn test_dlpack_tensor_metadata_default() {
let managed = &mut *ptr;
let tensor = &managed.dl_tensor;
- assert_eq!(tensor.ndim, 1, "Should be 1D tensor");
+ assert_eq!(tensor.ndim, 2, "Should be 2D tensor");
assert!(!tensor.data.is_null(), "Data pointer should be valid");
assert!(!tensor.shape.is_null(), "Shape pointer should be valid");
assert!(!tensor.strides.is_null(), "Strides pointer should be valid");
- let shape = *tensor.shape;
- assert_eq!(shape, 1024, "Shape should be 1024 (2^10)");
+ let shape = std::slice::from_raw_parts(tensor.shape, tensor.ndim as
usize);
+ assert_eq!(shape[0], 1, "First dimension should be 1 for single
encode");
+ assert_eq!(shape[1], 1024, "Second dimension should be 1024 (2^10)");
- let stride = *tensor.strides;
- assert_eq!(stride, 1, "Stride for 1D contiguous array should be 1");
+ let strides = std::slice::from_raw_parts(tensor.strides, tensor.ndim
as usize);
+ assert_eq!(strides[0], 1024, "Stride for first dimension should be
state_len");
+ assert_eq!(strides[1], 1, "Stride for second dimension should be 1");
assert_eq!(tensor.dtype.code, 5, "Should be complex type (code=5)");
- assert_eq!(tensor.dtype.bits, 64, "Should be 64 bits (2x32-bit
floats)");
+ assert_eq!(tensor.dtype.bits, 128, "Should be 128 bits (2x64-bit
floats, Float64)");
println!("PASS: DLPack metadata verified");
println!(" ndim: {}", tensor.ndim);
- println!(" shape: {}", shape);
- println!(" stride: {}", stride);
+ println!(" shape: [{}, {}]", shape[0], shape[1]);
+ println!(" strides: [{}, {}]", strides[0], strides[1]);
println!(
" dtype: code={}, bits={}",
tensor.dtype.code, tensor.dtype.bits
@@ -154,16 +156,18 @@ fn test_dlpack_tensor_metadata_f64() {
let managed = &mut *ptr;
let tensor = &managed.dl_tensor;
- assert_eq!(tensor.ndim, 1, "Should be 1D tensor");
+ assert_eq!(tensor.ndim, 2, "Should be 2D tensor");
assert!(!tensor.data.is_null(), "Data pointer should be valid");
assert!(!tensor.shape.is_null(), "Shape pointer should be valid");
assert!(!tensor.strides.is_null(), "Strides pointer should be valid");
- let shape = *tensor.shape;
- assert_eq!(shape, 1024, "Shape should be 1024 (2^10)");
+ let shape = std::slice::from_raw_parts(tensor.shape, tensor.ndim as
usize);
+ assert_eq!(shape[0], 1, "First dimension should be 1 for single
encode");
+ assert_eq!(shape[1], 1024, "Second dimension should be 1024 (2^10)");
- let stride = *tensor.strides;
- assert_eq!(stride, 1, "Stride for 1D contiguous array should be 1");
+ let strides = std::slice::from_raw_parts(tensor.strides, tensor.ndim
as usize);
+ assert_eq!(strides[0], 1024, "Stride for first dimension should be
state_len");
+ assert_eq!(strides[1], 1, "Stride for second dimension should be 1");
assert_eq!(tensor.dtype.code, 5, "Should be complex type (code=5)");
assert_eq!(
@@ -173,8 +177,8 @@ fn test_dlpack_tensor_metadata_f64() {
println!("PASS: DLPack metadata verified");
println!(" ndim: {}", tensor.ndim);
- println!(" shape: {}", shape);
- println!(" stride: {}", stride);
+ println!(" shape: [{}, {}]", shape[0], shape[1]);
+ println!(" strides: [{}, {}]", strides[0], strides[1]);
println!(
" dtype: code={}, bits={}",
tensor.dtype.code, tensor.dtype.bits
diff --git a/qdp/qdp-core/tests/validation.rs b/qdp/qdp-core/tests/validation.rs
index cc12a995a..6fc591e53 100644
--- a/qdp/qdp-core/tests/validation.rs
+++ b/qdp/qdp-core/tests/validation.rs
@@ -119,6 +119,30 @@ fn test_input_validation_max_qubits() {
}
}
+#[test]
+#[cfg(target_os = "linux")]
+fn test_input_validation_batch_zero_samples() {
+ println!("Testing zero num_samples rejection...");
+
+ let engine = match QdpEngine::new(0) {
+ Ok(e) => e,
+ Err(_) => return,
+ };
+
+ let batch_data = vec![1.0, 2.0, 3.0, 4.0];
+ let result = engine.encode_batch(&batch_data, 0, 4, 2, "amplitude");
+ assert!(result.is_err(), "Should reject zero num_samples");
+
+ match result {
+ Err(MahoutError::InvalidInput(msg)) => {
+ assert!(msg.contains("num_samples must be greater than 0"),
+ "Error should mention num_samples requirement");
+ println!("PASS: Correctly rejected zero num_samples: {}", msg);
+ }
+ _ => panic!("Expected InvalidInput error for zero num_samples"),
+ }
+}
+
#[test]
#[cfg(target_os = "linux")]
fn test_empty_data() {
diff --git a/qdp/qdp-python/src/lib.rs b/qdp/qdp-python/src/lib.rs
index d94aceeb2..642a9a7fa 100644
--- a/qdp/qdp-python/src/lib.rs
+++ b/qdp/qdp-python/src/lib.rs
@@ -188,6 +188,7 @@ impl QdpEngine {
///
/// Returns:
/// QuantumTensor: DLPack-compatible tensor for zero-copy PyTorch
integration
+ /// Shape: [1, 2^num_qubits]
///
/// Raises:
/// RuntimeError: If encoding fails
diff --git a/qdp/qdp-python/tests/test_bindings.py
b/qdp/qdp-python/tests/test_bindings.py
index d3cda3e22..7808abc8c 100644
--- a/qdp/qdp-python/tests/test_bindings.py
+++ b/qdp/qdp-python/tests/test_bindings.py
@@ -126,8 +126,8 @@ def test_pytorch_integration():
assert torch_tensor.device.index == 0
assert torch_tensor.dtype == torch.complex64
- # Verify shape (2 qubits = 2^2 = 4 elements)
- assert torch_tensor.shape == (4,)
+ # Verify shape (2 qubits = 2^2 = 4 elements) as 2D for consistency: [1, 4]
+ assert torch_tensor.shape == (1, 4)
@pytest.mark.gpu
diff --git a/qdp/qdp-python/tests/test_high_fidelity.py
b/qdp/qdp-python/tests/test_high_fidelity.py
index 24f11c513..9046272cb 100644
--- a/qdp/qdp-python/tests/test_high_fidelity.py
+++ b/qdp/qdp-python/tests/test_high_fidelity.py
@@ -36,6 +36,9 @@ def calculate_fidelity(
) -> float:
"""Calculate quantum state fidelity: F = |<ψ_gpu | ψ_cpu>|²"""
psi_gpu = state_vector_gpu.cpu().numpy()
+ # Convert 2D [1, state_len] to 1D for compatibility with ground truth
+ if psi_gpu.ndim == 2 and psi_gpu.shape[0] == 1:
+ psi_gpu = psi_gpu[0]
if np.any(np.isnan(psi_gpu)) or np.any(np.isinf(psi_gpu)):
return 0.0
@@ -103,7 +106,7 @@ def test_amplitude_encoding_fidelity_comprehensive(
assert torch_state.is_cuda, "Tensor must be on GPU"
assert torch_state.dtype == torch.complex128, "Tensor must be Complex128"
- assert torch_state.shape[0] == state_len, "Tensor shape must match 2^n"
+ assert torch_state.shape == (1, state_len), "Tensor shape must be [1, 2^n]"
fidelity = calculate_fidelity(torch_state, expected_state_complex)
print(f"Fidelity: {fidelity:.16f}")