This is an automated email from the ASF dual-hosted git repository.

guanmingchiu pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/mahout.git

commit 4b34a6363852935b74e7f8f7ee94d691ae1619d5
Author: Cheyu Wu <[email protected]>
AuthorDate: Sun Jan 4 17:07:13 2026 +0800

    MAHOUT-725: [QDP] PyTorch Tensor Detection and CPU Path (#763)
    
    * feat: [QDP] PyTorch Tensor Detection and CPU Path
    
    Signed-off-by: Cheyu Wu <[email protected]>
    
    * style: fix linter issue
    
    Signed-off-by: Cheyu Wu <[email protected]>
    
    * style: linter err
    
    Signed-off-by: Cheyu Wu <[email protected]>
    
    * doc: add comment for followup pr
    
    Signed-off-by: Cheyu Wu <[email protected]>
    
    * revert: ipynb linter fix
    
    Signed-off-by: Cheyu Wu <[email protected]>
    
    ---------
    
    Signed-off-by: Cheyu Wu <[email protected]>
---
 qdp/qdp-python/src/lib.rs             | 67 +++++++++++++++++++++++++++++++++++
 qdp/qdp-python/tests/test_bindings.py | 52 +++++++++++++++++++++++++--
 2 files changed, 116 insertions(+), 3 deletions(-)

diff --git a/qdp/qdp-python/src/lib.rs b/qdp/qdp-python/src/lib.rs
index df68be1bc..d06ce3d5d 100644
--- a/qdp/qdp-python/src/lib.rs
+++ b/qdp/qdp-python/src/lib.rs
@@ -139,6 +139,37 @@ impl Drop for QuantumTensor {
 unsafe impl Send for QuantumTensor {}
 unsafe impl Sync for QuantumTensor {}
 
+/// Helper to detect a PyTorch tensor.
+///
+/// Returns `Ok(true)` when `obj` is an instance of `torch.Tensor` or of any subclass
+/// of it (e.g. `torch.nn.Parameter`). Detection walks the type's `__mro__` and looks
+/// for a class named `Tensor` whose `__module__` is `torch`. As a result, `torch` does
+/// not need to be imported from Rust, and non-torch objects are rejected cheaply.
+///
+/// # Errors
+///
+/// Propagates any Python error raised while reading `__mro__`, `__name__` or
+/// `__module__`, or while extracting them as strings.
+fn is_pytorch_tensor(obj: &Bound<'_, PyAny>) -> PyResult<bool> {
+    // `__mro__` is a tuple of classes, starting with the object's own type.
+    let mro: Vec<Bound<'_, PyAny>> = obj.get_type().getattr("__mro__")?.extract()?;
+    for cls in &mro {
+        let name: String = cls.getattr("__name__")?.extract()?;
+        if name != "Tensor" {
+            continue;
+        }
+        // Only the module is checked for classes already named `Tensor`, so unrelated
+        // classes never have their `__module__` inspected.
+        let module: String = cls.getattr("__module__")?.extract()?;
+        if module == "torch" {
+            return Ok(true);
+        }
+    }
+    Ok(false)
+}
+
+/// Helper to validate that `tensor` is a PyTorch tensor living on the CPU.
+///
+/// # Errors
+///
+/// Returns a Python `RuntimeError` when `tensor` is not recognised as a PyTorch
+/// tensor. Also returns one when its `device.type` is anything other than `"cpu"`
+/// (e.g. `"cuda"`). Errors raised while reading `device` / `device.type` are
+/// propagated unchanged.
+fn validate_tensor(tensor: &Bound<'_, PyAny>) -> PyResult<()> {
+    if !is_pytorch_tensor(tensor)? {
+        return Err(PyRuntimeError::new_err("Object is not a PyTorch Tensor"));
+    }
+
+    // Only the device kind ("cpu", "cuda", ...) is compared; any device index is ignored.
+    let device = tensor.getattr("device")?;
+    let device_type: String = device.getattr("type")?.extract()?;
+
+    if device_type != "cpu" {
+        return Err(PyRuntimeError::new_err(format!(
+            "Only CPU tensors are currently supported for this path. Got device: {}",
+            device_type
+        )));
+    }
+
+    Ok(())
+}
+
 /// PyO3 wrapper for QdpEngine
 ///
 /// Provides Python bindings for GPU-accelerated quantum state encoding.
@@ -215,6 +246,42 @@ impl QdpEngine {
         })
     }
 
+    /// Encode from PyTorch Tensor
+    ///
+    /// Args:
+    ///     tensor: PyTorch Tensor (must be on CPU)
+    ///     num_qubits: Number of qubits for encoding
+    ///     encoding_method: Encoding strategy
+    ///
+    /// Returns:
+    ///     QuantumTensor: DLPack-compatible tensor
+    fn encode_tensor(
+        &self,
+        tensor: &Bound<'_, PyAny>,
+        num_qubits: usize,
+        encoding_method: &str,
+    ) -> PyResult<QuantumTensor> {
+        validate_tensor(tensor)?;
+
+        // NOTE(perf): `tolist()` + `extract()` makes extra copies (Tensor -> 
Python list -> Vec).
+        // TODO: follow-up PR can use `numpy()`/buffer protocol (and possibly 
pinned host memory)
+        // to reduce copy overhead.
+        let data: Vec<f64> = tensor
+            .call_method0("flatten")?
+            .call_method0("tolist")?
+            .extract()?;
+
+        let ptr = self
+            .engine
+            .encode(&data, num_qubits, encoding_method)
+            .map_err(|e| PyRuntimeError::new_err(format!("Encoding failed: 
{}", e)))?;
+
+        Ok(QuantumTensor {
+            ptr,
+            consumed: false,
+        })
+    }
+
     /// Encode from Parquet file
     ///
     /// Args:
diff --git a/qdp/qdp-python/tests/test_bindings.py b/qdp/qdp-python/tests/test_bindings.py
index 7808abc8c..ea23aceb7 100644
--- a/qdp/qdp-python/tests/test_bindings.py
+++ b/qdp/qdp-python/tests/test_bindings.py
@@ -77,9 +77,10 @@ def test_dlpack_device_id_non_zero():
     qtensor = engine.encode(data, 2, "amplitude")
 
     device_info = qtensor.__dlpack_device__()
-    assert device_info == (2, device_id), (
-        f"Expected (2, {device_id}) for CUDA device {device_id}"
-    )
+    assert device_info == (
+        2,
+        device_id,
+    ), f"Expected (2, {device_id}) for CUDA device {device_id}"
 
     # Verify PyTorch integration works with non-zero device_id
     torch_tensor = torch.from_dlpack(qtensor)
@@ -143,3 +144,48 @@ def test_pytorch_precision_float64():
 
     torch_tensor = torch.from_dlpack(qtensor)
     assert torch_tensor.dtype == torch.complex128
+
+
[email protected]
+def test_encode_tensor_cpu():
+    """Test encoding from CPU PyTorch tensor."""
+    pytest.importorskip("torch")
+    import torch
+    from mahout_qdp import QdpEngine
+
+    if not torch.cuda.is_available():
+        pytest.skip("GPU required for QdpEngine")
+
+    engine = QdpEngine(0)
+    data = torch.tensor([1.0, 2.0, 3.0, 4.0], dtype=torch.float64)
+    qtensor = engine.encode_tensor(data, 2, "amplitude")
+
+    # Verify result
+    torch_tensor = torch.from_dlpack(qtensor)
+    assert torch_tensor.is_cuda
+    assert torch_tensor.shape == (1, 4)
+
+
[email protected]
+def test_encode_tensor_errors():
+    """Test error handling for encode_tensor."""
+    pytest.importorskip("torch")
+    import torch
+    from mahout_qdp import QdpEngine
+
+    if not torch.cuda.is_available():
+        pytest.skip("GPU required for QdpEngine")
+
+    engine = QdpEngine(0)
+
+    # Test non-tensor input
+    with pytest.raises(RuntimeError, match="Object is not a PyTorch Tensor"):
+        engine.encode_tensor([1.0, 2.0], 1, "amplitude")
+
+    # Test GPU tensor input (should fail as only CPU is supported for this 
path)
+    if torch.cuda.is_available():
+        gpu_tensor = torch.tensor([1.0, 2.0], device="cuda:0")
+        with pytest.raises(
+            RuntimeError, match="Only CPU tensors are currently supported"
+        ):
+            engine.encode_tensor(gpu_tensor, 1, "amplitude")

Reply via email to