This is an automated email from the ASF dual-hosted git repository.

richhuang pushed a commit to branch dev-qdp
in repository https://gitbox.apache.org/repos/asf/mahout.git


The following commit(s) were added to refs/heads/dev-qdp by this push:
     new b6f1a3bc1 [QDP] DLPack shape/strides: Support batch 2D tensor
b6f1a3bc1 is described below

commit b6f1a3bc1b4f5c933407cb26da8e46ab7ee2f01f
Author: rich7420 <[email protected]>
AuthorDate: Fri Dec 19 16:25:15 2025 +0800

    [QDP] DLPack shape/strides: Support batch 2D tensor
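    
    With this change, to_dlpack() exports batch results as a 2D tensor of
    shape [num_samples, 2^num_qubits] with row-major strides (in elements),
    while single-state encodes remain 1D. A minimal consumer-side sketch in
    Python (the engine.encode_batch binding call and its capsule return
    value are assumptions for illustration, not part of this commit):
    
        import torch
    
        # Hypothetical binding call; assumed to return a DLPack capsule on the GPU.
        capsule = engine.encode_batch(data, n_samples, sample_size, n_qubits, "amplitude")
        gpu_batched = torch.from_dlpack(capsule)
    
        # The batch already arrives shaped [n_samples, 2**n_qubits], so no
        # view()/reshape is needed before converting for the model.
        assert gpu_batched.shape == (n_samples, 1 << n_qubits)
        features = gpu_batched.abs().to(torch.float32)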
---
 qdp/benchmark/benchmark_e2e.py      | 23 ++++++----
 qdp/qdp-core/src/dlpack.rs          | 17 +++++--
 qdp/qdp-core/src/gpu/memory.rs      |  5 ++
 qdp/qdp-core/src/preprocessing.rs   |  6 +++
 qdp/qdp-core/tests/api_workflow.rs  | 92 +++++++++++++++++++++++++++++++++++++
 qdp/qdp-core/tests/memory_safety.rs |  2 +-
 qdp/qdp-core/tests/validation.rs    | 24 ++++++++++
 7 files changed, 155 insertions(+), 14 deletions(-)

diff --git a/qdp/benchmark/benchmark_e2e.py b/qdp/benchmark/benchmark_e2e.py
index b72c81d1f..0d419d0bf 100644
--- a/qdp/benchmark/benchmark_e2e.py
+++ b/qdp/benchmark/benchmark_e2e.py
@@ -277,15 +277,17 @@ def run_mahout_parquet(engine, n_qubits, n_samples):
     dlpack_time = time.perf_counter() - dlpack_start
     print(f"  DLPack conversion: {dlpack_time:.4f} s")
 
-    # Reshape to [n_samples, state_len] (still complex)
+    # Tensor is already 2D [n_samples, state_len] from to_dlpack()
     state_len = 1 << n_qubits
+    assert gpu_batched.shape == (n_samples, state_len), (
+        f"Expected shape ({n_samples}, {state_len}), got {gpu_batched.shape}"
+    )
 
     # Convert to float for model (batch already on GPU)
     reshape_start = time.perf_counter()
-    gpu_reshaped = gpu_batched.view(n_samples, state_len)
-    gpu_all_data = gpu_reshaped.abs().to(torch.float32)
+    gpu_all_data = gpu_batched.abs().to(torch.float32)
     reshape_time = time.perf_counter() - reshape_start
-    print(f"  Reshape & convert: {reshape_time:.4f} s")
+    print(f"  Convert to float32: {reshape_time:.4f} s")
 
     # Forward pass (data already on GPU)
     for i in range(0, n_samples, BATCH_SIZE):
@@ -299,7 +301,7 @@ def run_mahout_parquet(engine, n_qubits, n_samples):
     # Clean cache after benchmark completion
     clean_cache()
 
-    return total_time, gpu_reshaped
+    return total_time, gpu_batched
 
 
 # -----------------------------------------------------------
@@ -325,13 +327,16 @@ def run_mahout_arrow(engine, n_qubits, n_samples):
     dlpack_time = time.perf_counter() - dlpack_start
     print(f"  DLPack conversion: {dlpack_time:.4f} s")
 
+    # Tensor is already 2D [n_samples, state_len] from to_dlpack()
     state_len = 1 << n_qubits
+    assert gpu_batched.shape == (n_samples, state_len), (
+        f"Expected shape ({n_samples}, {state_len}), got {gpu_batched.shape}"
+    )
 
     reshape_start = time.perf_counter()
-    gpu_reshaped = gpu_batched.view(n_samples, state_len)
-    gpu_all_data = gpu_reshaped.abs().to(torch.float32)
+    gpu_all_data = gpu_batched.abs().to(torch.float32)
     reshape_time = time.perf_counter() - reshape_start
-    print(f"  Reshape & convert: {reshape_time:.4f} s")
+    print(f"  Convert to float32: {reshape_time:.4f} s")
 
     for i in range(0, n_samples, BATCH_SIZE):
         batch = gpu_all_data[i : i + BATCH_SIZE]
@@ -344,7 +349,7 @@ def run_mahout_arrow(engine, n_qubits, n_samples):
     # Clean cache after benchmark completion
     clean_cache()
 
-    return total_time, gpu_reshaped
+    return total_time, gpu_batched
 
 
 def compare_states(name_a, states_a, name_b, states_b):
diff --git a/qdp/qdp-core/src/dlpack.rs b/qdp/qdp-core/src/dlpack.rs
index 883d19b37..5a2f7ecea 100644
--- a/qdp/qdp-core/src/dlpack.rs
+++ b/qdp/qdp-core/src/dlpack.rs
@@ -120,9 +120,18 @@ impl GpuStateVector {
     /// Freed by DLPack deleter when PyTorch releases tensor.
     /// Do not free manually.
     pub fn to_dlpack(&self) -> *mut DLManagedTensor {
-        // Allocate shape/strides on heap (freed by deleter)
-        let shape = vec![self.size_elements as i64];
-        let strides = vec![1i64];
+        let (shape, strides, ndim) = if let Some(num_samples) = self.num_samples {
+            // Batch: 2D shape [num_samples, state_len_per_sample], row-major strides
+            let state_len_per_sample = self.size_elements / num_samples;
+            let shape = vec![num_samples as i64, state_len_per_sample as i64];
+            let strides = vec![state_len_per_sample as i64, 1i64]; // Strides in elements, not bytes
+            (shape, strides, 2)
+        } else {
+            // Single state: 1D shape [size_elements]
+            let shape = vec![self.size_elements as i64];
+            let strides = vec![1i64];
+            (shape, strides, 1)
+        };
 
         // Transfer ownership to DLPack deleter
         let shape_ptr = Box::into_raw(shape.into_boxed_slice()) as *mut i64;
@@ -142,7 +151,7 @@ impl GpuStateVector {
                 device_type: DLDeviceType::kDLCUDA,
                 device_id: 0,
             },
-            ndim: 1,
+            ndim,
             dtype: DLDataType {
                 code: DL_COMPLEX,
                 bits: dtype_bits,
diff --git a/qdp/qdp-core/src/gpu/memory.rs b/qdp/qdp-core/src/gpu/memory.rs
index 26e7b1383..240ec54cf 100644
--- a/qdp/qdp-core/src/gpu/memory.rs
+++ b/qdp/qdp-core/src/gpu/memory.rs
@@ -190,6 +190,8 @@ pub struct GpuStateVector {
     pub(crate) buffer: Arc<BufferStorage>,
     pub num_qubits: usize,
     pub size_elements: usize,
+    /// Number of samples in batch. None for single state, Some(n) for batch.
+    pub(crate) num_samples: Option<usize>,
 }
 
 // Safety: CudaSlice and Arc are both Send + Sync
@@ -229,6 +231,7 @@ impl GpuStateVector {
                 buffer: Arc::new(BufferStorage::F64(GpuBufferRaw { slice })),
                 num_qubits: qubits,
                 size_elements: _size_elements,
+                num_samples: None,
             })
         }
 
@@ -300,6 +303,7 @@ impl GpuStateVector {
                 buffer: Arc::new(BufferStorage::F64(GpuBufferRaw { slice })),
                 num_qubits: qubits,
                 size_elements: total_elements,
+                num_samples: Some(num_samples),
             })
         }
 
@@ -364,6 +368,7 @@ impl GpuStateVector {
                         buffer: Arc::new(BufferStorage::F32(GpuBufferRaw { slice })),
                         num_qubits: self.num_qubits,
                         size_elements: self.size_elements,
+                        num_samples: self.num_samples, // Preserve batch information
                     })
                 }
 
diff --git a/qdp/qdp-core/src/preprocessing.rs b/qdp/qdp-core/src/preprocessing.rs
index 0d8e70148..43577a8eb 100644
--- a/qdp/qdp-core/src/preprocessing.rs
+++ b/qdp/qdp-core/src/preprocessing.rs
@@ -84,6 +84,12 @@ impl Preprocessor {
         sample_size: usize,
         num_qubits: usize,
     ) -> Result<()> {
+        if num_samples == 0 {
+            return Err(MahoutError::InvalidInput(
+                "num_samples must be greater than 0".to_string()
+            ));
+        }
+
         if batch_data.len() != num_samples * sample_size {
             return Err(MahoutError::InvalidInput(
                 format!("Batch data length {} doesn't match num_samples {} * sample_size {}",
diff --git a/qdp/qdp-core/tests/api_workflow.rs b/qdp/qdp-core/tests/api_workflow.rs
index a1e97e31a..7e9de3081 100644
--- a/qdp/qdp-core/tests/api_workflow.rs
+++ b/qdp/qdp-core/tests/api_workflow.rs
@@ -107,3 +107,95 @@ fn test_amplitude_encoding_async_pipeline() {
         println!("PASS: Memory freed successfully");
     }
 }
+
+#[test]
+#[cfg(target_os = "linux")]
+fn test_batch_dlpack_2d_shape() {
+    println!("Testing batch DLPack 2D shape...");
+
+    let engine = match QdpEngine::new(0) {
+        Ok(e) => e,
+        Err(_) => {
+            println!("SKIP: No GPU available");
+            return;
+        }
+    };
+
+    // Create batch data: 3 samples, each with 4 elements (2 qubits)
+    let num_samples = 3;
+    let num_qubits = 2;
+    let sample_size = 4;
+    let batch_data: Vec<f64> = (0..num_samples * sample_size)
+        .map(|i| (i as f64) / 10.0)
+        .collect();
+
+    let result = engine.encode_batch(&batch_data, num_samples, sample_size, num_qubits, "amplitude");
+    assert!(result.is_ok(), "Batch encoding should succeed");
+
+    let dlpack_ptr = result.unwrap();
+    assert!(!dlpack_ptr.is_null(), "DLPack pointer should not be null");
+
+    unsafe {
+        let managed = &*dlpack_ptr;
+        let tensor = &managed.dl_tensor;
+
+        // Verify 2D shape for batch tensor
+        assert_eq!(tensor.ndim, 2, "Batch tensor should be 2D");
+
+        let shape_slice = std::slice::from_raw_parts(tensor.shape, tensor.ndim as usize);
+        assert_eq!(shape_slice[0], num_samples as i64, "First dimension should be num_samples");
+        assert_eq!(shape_slice[1], (1 << num_qubits) as i64, "Second dimension should be 2^num_qubits");
+
+        let strides_slice = std::slice::from_raw_parts(tensor.strides, tensor.ndim as usize);
+        let state_len = 1 << num_qubits;
+        assert_eq!(strides_slice[0], state_len as i64, "Stride for first dimension should be state_len");
+        assert_eq!(strides_slice[1], 1, "Stride for second dimension should be 1");
+
+        println!("PASS: Batch DLPack tensor has correct 2D shape: [{}, {}]", shape_slice[0], shape_slice[1]);
+        println!("PASS: Strides are correct: [{}, {}]", strides_slice[0], strides_slice[1]);
+
+        // Free memory
+        if let Some(deleter) = managed.deleter {
+            deleter(dlpack_ptr);
+        }
+    }
+}
+
+#[test]
+#[cfg(target_os = "linux")]
+fn test_single_encode_still_1d() {
+    println!("Testing single encode still returns 1D shape...");
+
+    let engine = match QdpEngine::new(0) {
+        Ok(e) => e,
+        Err(_) => {
+            println!("SKIP: No GPU available");
+            return;
+        }
+    };
+
+    let data = common::create_test_data(16);
+    let result = engine.encode(&data, 4, "amplitude");
+    assert!(result.is_ok(), "Encoding should succeed");
+
+    let dlpack_ptr = result.unwrap();
+    assert!(!dlpack_ptr.is_null(), "DLPack pointer should not be null");
+
+    unsafe {
+        let managed = &*dlpack_ptr;
+        let tensor = &managed.dl_tensor;
+
+        // Verify 1D shape for single encode (backward compatibility)
+        assert_eq!(tensor.ndim, 1, "Single encode should still be 1D");
+
+        let shape_slice = std::slice::from_raw_parts(tensor.shape, tensor.ndim as usize);
+        assert_eq!(shape_slice[0], 16, "Single encode shape should be [2^4]");
+
+        println!("PASS: Single encode still returns 1D shape: [{}]", shape_slice[0]);
+
+        // Free memory
+        if let Some(deleter) = managed.deleter {
+            deleter(dlpack_ptr);
+        }
+    }
+}
diff --git a/qdp/qdp-core/tests/memory_safety.rs b/qdp/qdp-core/tests/memory_safety.rs
index 833190c48..37f45478a 100644
--- a/qdp/qdp-core/tests/memory_safety.rs
+++ b/qdp/qdp-core/tests/memory_safety.rs
@@ -114,7 +114,7 @@ fn test_dlpack_tensor_metadata() {
         assert_eq!(stride, 1, "Stride for 1D contiguous array should be 1");
 
         assert_eq!(tensor.dtype.code, 5, "Should be complex type (code=5)");
-        assert_eq!(tensor.dtype.bits, 128, "Should be 128 bits (2x64-bit floats)");
+        assert_eq!(tensor.dtype.bits, 64, "Should be 64 bits (2x32-bit floats, Float32 default)");
 
         println!("PASS: DLPack metadata verified");
         println!("  ndim: {}", tensor.ndim);
diff --git a/qdp/qdp-core/tests/validation.rs b/qdp/qdp-core/tests/validation.rs
index cc12a995a..6fc591e53 100644
--- a/qdp/qdp-core/tests/validation.rs
+++ b/qdp/qdp-core/tests/validation.rs
@@ -119,6 +119,30 @@ fn test_input_validation_max_qubits() {
     }
 }
 
+#[test]
+#[cfg(target_os = "linux")]
+fn test_input_validation_batch_zero_samples() {
+    println!("Testing zero num_samples rejection...");
+
+    let engine = match QdpEngine::new(0) {
+        Ok(e) => e,
+        Err(_) => return,
+    };
+
+    let batch_data = vec![1.0, 2.0, 3.0, 4.0];
+    let result = engine.encode_batch(&batch_data, 0, 4, 2, "amplitude");
+    assert!(result.is_err(), "Should reject zero num_samples");
+
+    match result {
+        Err(MahoutError::InvalidInput(msg)) => {
+            assert!(msg.contains("num_samples must be greater than 0"),
+                    "Error should mention num_samples requirement");
+            println!("PASS: Correctly rejected zero num_samples: {}", msg);
+        }
+        _ => panic!("Expected InvalidInput error for zero num_samples"),
+    }
+}
+
 #[test]
 #[cfg(target_os = "linux")]
 fn test_empty_data() {
