This is an automated email from the ASF dual-hosted git repository.

guanmingchiu pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/mahout.git


The following commit(s) were added to refs/heads/main by this push:
     new d4a678d60 [QDP] Add PyTorch shape validation (#810)
d4a678d60 is described below

commit d4a678d60a1e94b980e6b24c1845b5aa8557abab
Author: Guan-Ming (Wesley) Chiu <[email protected]>
AuthorDate: Mon Jan 12 23:59:33 2026 +0800

    [QDP] Add PyTorch shape validation (#810)
---
 qdp/qdp-python/src/lib.rs | 66 +++++++++++++++++++++++++++++++++++++----------
 1 file changed, 52 insertions(+), 14 deletions(-)

diff --git a/qdp/qdp-python/src/lib.rs b/qdp/qdp-python/src/lib.rs
index ba9bbd3ac..686795c5e 100644
--- a/qdp/qdp-python/src/lib.rs
+++ b/qdp/qdp-python/src/lib.rs
@@ -263,7 +263,7 @@ impl QdpEngine {
 
             match ndim {
                 1 => {
-                    // 1D array: single sample encoding
+                    // 1D array: single sample encoding (zero-copy if already 
contiguous)
                     let array_1d = 
data.extract::<PyReadonlyArray1<f64>>().map_err(|_| {
                         PyRuntimeError::new_err(
                             "Failed to extract 1D NumPy array. Ensure dtype is 
float64.",
@@ -282,7 +282,7 @@ impl QdpEngine {
                     });
                 }
                 2 => {
-                    // 2D array: batch encoding
+                    // 2D array: batch encoding (zero-copy if already 
contiguous)
                     let array_2d = 
data.extract::<PyReadonlyArray2<f64>>().map_err(|_| {
                         PyRuntimeError::new_err(
                             "Failed to extract 2D NumPy array. Ensure dtype is 
float64.",
@@ -322,18 +322,56 @@ impl QdpEngine {
         // Check if it's a PyTorch tensor
         if is_pytorch_tensor(data)? {
             validate_tensor(data)?;
-            let vec_data: Vec<f64> = data
-                .call_method0("flatten")?
-                .call_method0("tolist")?
-                .extract()?;
-            let ptr = self
-                .engine
-                .encode(&vec_data, num_qubits, encoding_method)
-                .map_err(|e| PyRuntimeError::new_err(format!("Encoding failed: 
{}", e)))?;
-            return Ok(QuantumTensor {
-                ptr,
-                consumed: false,
-            });
+            // NOTE(perf): `tolist()` + `extract()` makes extra copies (Tensor 
-> Python list -> Vec).
+            // TODO: Follow-up PR can use `numpy()`/buffer protocol (and 
possibly pinned host memory)
+            // to reduce copy overhead.
+            let ndim: usize = data.call_method0("dim")?.extract()?;
+
+            match ndim {
+                1 => {
+                    // 1D tensor: single sample encoding
+                    let vec_data: Vec<f64> = 
data.call_method0("tolist")?.extract()?;
+                    let ptr = self
+                        .engine
+                        .encode(&vec_data, num_qubits, encoding_method)
+                        .map_err(|e| PyRuntimeError::new_err(format!("Encoding 
failed: {}", e)))?;
+                    return Ok(QuantumTensor {
+                        ptr,
+                        consumed: false,
+                    });
+                }
+                2 => {
+                    // 2D tensor: batch encoding
+                    let shape: Vec<i64> = data.getattr("shape")?.extract()?;
+                    let num_samples = shape[0] as usize;
+                    let sample_size = shape[1] as usize;
+                    let vec_data: Vec<f64> = data
+                        .call_method0("flatten")?
+                        .call_method0("tolist")?
+                        .extract()?;
+                    let ptr = self
+                        .engine
+                        .encode_batch(
+                            &vec_data,
+                            num_samples,
+                            sample_size,
+                            num_qubits,
+                            encoding_method,
+                        )
+                        .map_err(|e| PyRuntimeError::new_err(format!("Encoding 
failed: {}", e)))?;
+                    return Ok(QuantumTensor {
+                        ptr,
+                        consumed: false,
+                    });
+                }
+                _ => {
+                    return Err(PyRuntimeError::new_err(format!(
+                        "Unsupported tensor shape: {}D. Expected 1D tensor for 
single sample \
+                         encoding or 2D tensor (batch_size, features) for 
batch encoding.",
+                        ndim
+                    )));
+                }
+            }
         }
 
         // Fallback: try to extract as Vec<f64> (Python list)

Reply via email to