This is an automated email from the ASF dual-hosted git repository.
richhuang pushed a commit to branch dev-qdp
in repository https://gitbox.apache.org/repos/asf/mahout.git
The following commit(s) were added to refs/heads/dev-qdp by this push:
new 6bb83829f MAHOUT-725: [QDP] PyTorch Tensor Detection and CPU Path (#763)
6bb83829f is described below
commit 6bb83829f786ae49a765c9f14479fd229eebe683
Author: Cheyu Wu <[email protected]>
AuthorDate: Sun Jan 4 17:07:13 2026 +0800
MAHOUT-725: [QDP] PyTorch Tensor Detection and CPU Path (#763)
* feat: [QDP] PyTorch Tensor Detection and CPU Path
Signed-off-by: Cheyu Wu <[email protected]>
* style: fix linter issue
Signed-off-by: Cheyu Wu <[email protected]>
* style: fix linter error
Signed-off-by: Cheyu Wu <[email protected]>
* doc: add comment for follow-up PR
Signed-off-by: Cheyu Wu <[email protected]>
* revert: ipynb linter fix
Signed-off-by: Cheyu Wu <[email protected]>
---------
Signed-off-by: Cheyu Wu <[email protected]>
---
qdp/qdp-python/src/lib.rs | 67 +++++++++++++++++++++++++++++++++++
qdp/qdp-python/tests/test_bindings.py | 52 +++++++++++++++++++++++++--
2 files changed, 116 insertions(+), 3 deletions(-)
diff --git a/qdp/qdp-python/src/lib.rs b/qdp/qdp-python/src/lib.rs
index df68be1bc..d06ce3d5d 100644
--- a/qdp/qdp-python/src/lib.rs
+++ b/qdp/qdp-python/src/lib.rs
@@ -139,6 +139,37 @@ impl Drop for QuantumTensor {
// SAFETY: NOTE(review): these impls allow `QuantumTensor` to be moved to and
// shared between threads. The invariant that makes this sound (e.g. who owns
// and may free the wrapped `ptr`) is not visible from this view — confirm it
// and state it here explicitly.
unsafe impl Send for QuantumTensor {}
unsafe impl Sync for QuantumTensor {}
+/// Helper to detect PyTorch tensor
+fn is_pytorch_tensor(obj: &Bound<'_, PyAny>) -> PyResult<bool> {
+ let type_obj = obj.get_type();
+ let name = type_obj.name()?;
+ if name != "Tensor" {
+ return Ok(false);
+ }
+ let module = type_obj.module()?;
+ let module_name = module.to_str()?;
+ Ok(module_name == "torch")
+}
+
+/// Helper to validate tensor
+fn validate_tensor(tensor: &Bound<'_, PyAny>) -> PyResult<()> {
+ if !is_pytorch_tensor(tensor)? {
+ return Err(PyRuntimeError::new_err("Object is not a PyTorch Tensor"));
+ }
+
+ let device = tensor.getattr("device")?;
+ let device_type: String = device.getattr("type")?.extract()?;
+
+ if device_type != "cpu" {
+ return Err(PyRuntimeError::new_err(format!(
+ "Only CPU tensors are currently supported for this path. Got
device: {}",
+ device_type
+ )));
+ }
+
+ Ok(())
+}
+
/// PyO3 wrapper for QdpEngine
///
/// Provides Python bindings for GPU-accelerated quantum state encoding.
@@ -215,6 +246,42 @@ impl QdpEngine {
})
}
+ /// Encode from PyTorch Tensor
+ ///
+ /// Args:
+ /// tensor: PyTorch Tensor (must be on CPU)
+ /// num_qubits: Number of qubits for encoding
+ /// encoding_method: Encoding strategy
+ ///
+ /// Returns:
+ /// QuantumTensor: DLPack-compatible tensor
+ fn encode_tensor(
+ &self,
+ tensor: &Bound<'_, PyAny>,
+ num_qubits: usize,
+ encoding_method: &str,
+ ) -> PyResult<QuantumTensor> {
+ validate_tensor(tensor)?;
+
+ // NOTE(perf): `tolist()` + `extract()` makes extra copies (Tensor ->
Python list -> Vec).
+ // TODO: follow-up PR can use `numpy()`/buffer protocol (and possibly
pinned host memory)
+ // to reduce copy overhead.
+ let data: Vec<f64> = tensor
+ .call_method0("flatten")?
+ .call_method0("tolist")?
+ .extract()?;
+
+ let ptr = self
+ .engine
+ .encode(&data, num_qubits, encoding_method)
+ .map_err(|e| PyRuntimeError::new_err(format!("Encoding failed:
{}", e)))?;
+
+ Ok(QuantumTensor {
+ ptr,
+ consumed: false,
+ })
+ }
+
/// Encode from Parquet file
///
/// Args:
diff --git a/qdp/qdp-python/tests/test_bindings.py b/qdp/qdp-python/tests/test_bindings.py
index 7808abc8c..ea23aceb7 100644
--- a/qdp/qdp-python/tests/test_bindings.py
+++ b/qdp/qdp-python/tests/test_bindings.py
@@ -77,9 +77,10 @@ def test_dlpack_device_id_non_zero():
qtensor = engine.encode(data, 2, "amplitude")
device_info = qtensor.__dlpack_device__()
- assert device_info == (2, device_id), (
- f"Expected (2, {device_id}) for CUDA device {device_id}"
- )
+ assert device_info == (
+ 2,
+ device_id,
+ ), f"Expected (2, {device_id}) for CUDA device {device_id}"
# Verify PyTorch integration works with non-zero device_id
torch_tensor = torch.from_dlpack(qtensor)
@@ -143,3 +144,48 @@ def test_pytorch_precision_float64():
torch_tensor = torch.from_dlpack(qtensor)
assert torch_tensor.dtype == torch.complex128
+
+
@pytest.mark.gpu
def test_encode_tensor_cpu():
    """Test encoding from a CPU PyTorch tensor.

    Checks that the result lives on the GPU with shape ``(1, 4)``. It also
    checks that the values agree with ``QdpEngine.encode`` applied to the same
    numbers passed as a plain list, including when the input tensor is
    multi-dimensional (``encode_tensor`` flattens its input).
    """
    pytest.importorskip("torch")
    import torch
    from mahout_qdp import QdpEngine

    if not torch.cuda.is_available():
        pytest.skip("GPU required for QdpEngine")

    engine = QdpEngine(0)
    values = [1.0, 2.0, 3.0, 4.0]
    data = torch.tensor(values, dtype=torch.float64)
    qtensor = engine.encode_tensor(data, 2, "amplitude")

    # Verify placement and shape of the result
    torch_tensor = torch.from_dlpack(qtensor)
    assert torch_tensor.is_cuda
    assert torch_tensor.shape == (1, 4)

    # The tensor path must agree with the list path on the same input values.
    expected = torch.from_dlpack(engine.encode(values, 2, "amplitude"))
    assert torch.allclose(torch_tensor, expected)

    # A 2-D tensor is flattened before encoding, so it must give the same state.
    qtensor_2d = engine.encode_tensor(data.reshape(2, 2), 2, "amplitude")
    assert torch.allclose(torch.from_dlpack(qtensor_2d), expected)
+
@pytest.mark.gpu
def test_encode_tensor_errors():
    """Test error handling for encode_tensor.

    Covers rejection of non-tensor input and of tensors that are not on the CPU.
    """
    pytest.importorskip("torch")
    import torch
    from mahout_qdp import QdpEngine

    if not torch.cuda.is_available():
        pytest.skip("GPU required for QdpEngine")

    engine = QdpEngine(0)

    # Test non-tensor input
    with pytest.raises(RuntimeError, match="Object is not a PyTorch Tensor"):
        engine.encode_tensor([1.0, 2.0], 1, "amplitude")

    # Test GPU tensor input (should fail as only CPU is supported for this path).
    # CUDA availability was already checked above, so no further guard is needed.
    gpu_tensor = torch.tensor([1.0, 2.0], device="cuda:0")
    with pytest.raises(RuntimeError, match="Only CPU tensors are currently supported"):
        engine.encode_tensor(gpu_tensor, 1, "amplitude")