This is an automated email from the ASF dual-hosted git repository.
jiekaichang pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/mahout.git
The following commit(s) were added to refs/heads/main by this push:
new d176ba52e [QDP] DLPack input extraction (#935)
d176ba52e is described below
commit d176ba52e3e987abf69765a298820ca436272af3
Author: Jie-Kai Chang <[email protected]>
AuthorDate: Sat Jan 31 21:14:58 2026 +0800
[QDP] DLPack input extraction (#935)
* DLPack input extraction
Signed-off-by: 400Ping <[email protected]>
* fix CI error
Signed-off-by: 400Ping <[email protected]>
* fix
Signed-off-by: 400Ping <[email protected]>
* update
Signed-off-by: 400Ping <[email protected]>
---------
Signed-off-by: 400Ping <[email protected]>
Signed-off-by: 400Ping <[email protected]>
---
 qdp/qdp-core/src/dlpack.rs                     |   3 +
 qdp/qdp-python/pyproject.toml                  |   1 +
 qdp/qdp-python/src/lib.rs                      | 108 +++++++++++++++++++++----
 qdp/qdp-python/tests/test_dlpack_validation.py | 100 +++++++++++++++++++++++
 4 files changed, 197 insertions(+), 15 deletions(-)
diff --git a/qdp/qdp-core/src/dlpack.rs b/qdp/qdp-core/src/dlpack.rs
index eb3e596be..1780fa404 100644
--- a/qdp/qdp-core/src/dlpack.rs
+++ b/qdp/qdp-core/src/dlpack.rs
@@ -105,7 +105,10 @@ pub unsafe fn synchronize_stream(_stream: *mut c_void) -> Result<()> {
 // DLPack C structures (matching dlpack/dlpack.h)
+/// Device type enum for DLPack. Eq/PartialEq used for validation (e.g. device_type != kDLCUDA);
+/// Debug for diagnostics; Copy/Clone for FFI ergonomics when used in DLDevice.
 #[repr(C)]
+#[derive(Clone, Copy, Debug, Eq, PartialEq)]
 #[allow(non_camel_case_types)]
 pub enum DLDeviceType {
     kDLCPU = 1,
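For context on the derive added above: PartialEq/Eq are what allow the device check later in this commit (device_type != DLDeviceType::kDLCUDA) to compile, and Debug lets a rejected variant be printed in diagnostics. A minimal standalone sketch, not part of this commit; the trimmed enum copy (kDLCPU = 1, kDLCUDA = 2, per dlpack.h) and the check_device helper are illustrative only:

    // Sketch only: a trimmed copy of the enum; the real one lives in qdp_core::dlpack.
    #[repr(C)]
    #[derive(Clone, Copy, Debug, Eq, PartialEq)]
    #[allow(non_camel_case_types)]
    enum DLDeviceType {
        kDLCPU = 1,
        kDLCUDA = 2,
    }

    // Hypothetical helper: PartialEq/Eq make the comparison possible,
    // Debug formats the variant name into the error message.
    fn check_device(device_type: DLDeviceType) -> Result<(), String> {
        if device_type != DLDeviceType::kDLCUDA {
            return Err(format!("expected kDLCUDA, got {:?}", device_type));
        }
        Ok(())
    }

    fn main() {
        assert!(check_device(DLDeviceType::kDLCUDA).is_ok());
        assert_eq!(
            check_device(DLDeviceType::kDLCPU).unwrap_err(),
            "expected kDLCUDA, got kDLCPU"
        );
    }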
diff --git a/qdp/qdp-python/pyproject.toml b/qdp/qdp-python/pyproject.toml
index 909cd4a34..72f663a5a 100644
--- a/qdp/qdp-python/pyproject.toml
+++ b/qdp/qdp-python/pyproject.toml
@@ -15,6 +15,7 @@ classifiers = [
 dynamic = ["version"]
 
 [dependency-groups]
+dev = ["pytest", "torch>=2.2,<=2.9.0"]
 benchmark = [
     "numpy>=1.24,<2.0",
     "pandas>=2.0",
diff --git a/qdp/qdp-python/src/lib.rs b/qdp/qdp-python/src/lib.rs
index 1a6ebe735..fd655b2da 100644
--- a/qdp/qdp-python/src/lib.rs
+++ b/qdp/qdp-python/src/lib.rs
@@ -18,7 +18,7 @@ use numpy::{PyReadonlyArray1, PyReadonlyArray2, PyUntypedArrayMethods};
 use pyo3::exceptions::PyRuntimeError;
 use pyo3::ffi;
 use pyo3::prelude::*;
-use qdp_core::dlpack::DLManagedTensor;
+use qdp_core::dlpack::{DL_FLOAT, DLDeviceType, DLManagedTensor};
 use qdp_core::{Precision, QdpEngine as CoreEngine};
 use std::ffi::c_void;
@@ -409,46 +409,124 @@ fn extract_dlpack_tensor(_py: Python<'_>, tensor: &Bound<'_, PyAny>) -> PyResult
     // Note: PyTorch's __dlpack__ uses the default stream when called without arguments
     let capsule = tensor.call_method0("__dlpack__")?;
-    // Extract the DLManagedTensor pointer from the capsule
     const DLTENSOR_NAME: &[u8] = b"dltensor\0";
-    unsafe {
+    // SAFETY: capsule is a valid PyCapsule from tensor.__dlpack__(). DLTENSOR_NAME is a
+    // null-terminated C string for the lifetime of the call. We only read the capsule
+    // and call PyCapsule_IsValid / PyCapsule_GetPointer; we do not invalidate the capsule.
+    let managed_ptr = unsafe {
         let capsule_ptr = capsule.as_ptr();
-        let managed_ptr =
-            ffi::PyCapsule_GetPointer(capsule_ptr, DLTENSOR_NAME.as_ptr() as *const i8)
-                as *mut DLManagedTensor;
-
-        if managed_ptr.is_null() {
+        if ffi::PyCapsule_IsValid(capsule_ptr, DLTENSOR_NAME.as_ptr() as *const i8) == 0 {
+            return Err(PyRuntimeError::new_err(
+                "Invalid DLPack capsule (expected 'dltensor')",
+            ));
+        }
+        let ptr = ffi::PyCapsule_GetPointer(capsule_ptr, DLTENSOR_NAME.as_ptr() as *const i8)
+            as *mut DLManagedTensor;
+        if ptr.is_null() {
             return Err(PyRuntimeError::new_err(
                 "Failed to extract DLManagedTensor from PyCapsule",
             ));
         }
+        ptr
+    };
+    // SAFETY: managed_ptr is non-null and was returned by PyCapsule_GetPointer for a valid
+    // "dltensor" capsule, so it points to a valid DLManagedTensor. The capsule (and thus
+    // the tensor) is held by the caller for the duration of this function. We read fields
+    // and create slices from shape/strides only when non-null and ndim is valid.
+    unsafe {
         let dl_tensor = &(*managed_ptr).dl_tensor;
-        // Extract data pointer with null check
         if dl_tensor.data.is_null() {
             return Err(PyRuntimeError::new_err(
                 "DLPack tensor has null data pointer",
             ));
         }
-        let data_ptr = dl_tensor.data as *const f64;
-        // Extract shape
+        if dl_tensor.device.device_type != DLDeviceType::kDLCUDA {
+            return Err(PyRuntimeError::new_err(
+                "DLPack tensor must be on CUDA device",
+            ));
+        }
+
+        if dl_tensor.dtype.code != DL_FLOAT
+            || dl_tensor.dtype.bits != 64
+            || dl_tensor.dtype.lanes != 1
+        {
+            return Err(PyRuntimeError::new_err(format!(
+                "DLPack tensor must be float64 (code={}, bits={}, lanes={})",
+                dl_tensor.dtype.code, dl_tensor.dtype.bits, dl_tensor.dtype.lanes
+            )));
+        }
+
+        if !dl_tensor
+            .byte_offset
+            .is_multiple_of(std::mem::size_of::<f64>() as u64)
+        {
+            return Err(PyRuntimeError::new_err(
+                "DLPack tensor byte_offset is not aligned for float64",
+            ));
+        }
+
+        let data_ptr =
+            (dl_tensor.data as *const u8).add(dl_tensor.byte_offset as usize) as *const f64;
+
         let ndim = dl_tensor.ndim as usize;
+        // SAFETY: shape pointer is valid for ndim elements when non-null (DLPack contract).
         let shape = if ndim > 0 && !dl_tensor.shape.is_null() {
             std::slice::from_raw_parts(dl_tensor.shape, ndim).to_vec()
         } else {
             vec![]
         };
-        // Extract device_id
+        if ndim == 0 || shape.is_empty() {
+            return Err(PyRuntimeError::new_err(
+                "DLPack tensor must have at least 1 dimension",
+            ));
+        }
+
+        if !dl_tensor.strides.is_null() {
+            // SAFETY: strides pointer is valid for ndim elements (DLPack contract).
+            let strides = std::slice::from_raw_parts(dl_tensor.strides, ndim);
+            match ndim {
+                1 => {
+                    let expected = 1_i64;
+                    if strides[0] != expected {
+                        return Err(PyRuntimeError::new_err(format!(
+                            "DLPack tensor must be contiguous: stride[0]={}, expected {}",
+                            strides[0], expected
+                        )));
+                    }
+                }
+                2 => {
+                    if shape.len() < 2 {
+                        return Err(PyRuntimeError::new_err(
+                            "DLPack tensor must be contiguous (shape len < 2)",
+                        ));
+                    }
+                    let expected_stride_1 = 1_i64;
+                    let expected_stride_0 = shape[1];
+                    if strides[1] != expected_stride_1 || strides[0] != expected_stride_0 {
+                        return Err(PyRuntimeError::new_err(format!(
+                            "DLPack tensor must be contiguous: strides=[{}, {}], expected [{}, {}] (expected[1]=shape[1])",
+                            strides[0], strides[1], expected_stride_0, expected_stride_1
+                        )));
+                    }
+                }
+                _ => {
+                    return Err(PyRuntimeError::new_err(
+                        "DLPack tensor must be 1D or 2D for encoding",
+                    ));
+                }
+            }
+        }
+
         let device_id = dl_tensor.device.device_id;
-        // Rename the capsule to "used_dltensor" as per DLPack protocol
-        // This prevents PyTorch from trying to delete it when the capsule is garbage collected
         const USED_DLTENSOR_NAME: &[u8] = b"used_dltensor\0";
-        ffi::PyCapsule_SetName(capsule_ptr, USED_DLTENSOR_NAME.as_ptr() as *const i8);
+        // SAFETY: capsule is the same PyCapsule we used above; renaming is allowed and does not free it.
+        ffi::PyCapsule_SetName(capsule.as_ptr(), USED_DLTENSOR_NAME.as_ptr() as *const i8);
         Ok(DLPackTensorInfo {
             managed_ptr,
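The 1D and 2D stride checks in this hunk are special cases of the general row-major (C-contiguous) rule: strides[i] must equal the product of shape[i+1..], so a 1D tensor needs strides [1] and a (rows, cols) tensor needs strides [cols, 1]. A minimal sketch of that general rule, not part of this commit; is_c_contiguous is a hypothetical helper that assumes shape and strides have equal length and ignores the usual relaxation for size-1 dimensions:

    // Sketch only: generic row-major contiguity check over DLPack-style
    // i64 shape/strides slices, walking dimensions right to left.
    fn is_c_contiguous(shape: &[i64], strides: &[i64]) -> bool {
        let mut expected = 1i64;
        for i in (0..shape.len()).rev() {
            if strides[i] != expected {
                return false;
            }
            expected *= shape[i];
        }
        true
    }

    fn main() {
        assert!(is_c_contiguous(&[4], &[1]));
        assert!(is_c_contiguous(&[4, 2], &[2, 1]));
        // The sliced (4, 2) tensor in the tests below has strides [3, 2].
        assert!(!is_c_contiguous(&[4, 2], &[3, 2]));
    }

The committed code instead inlines the 1D and 2D cases and rejects ndim > 2 outright, which keeps each error message specific to the shape being checked.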
diff --git a/qdp/qdp-python/tests/test_dlpack_validation.py b/qdp/qdp-python/tests/test_dlpack_validation.py
new file mode 100644
index 000000000..84335c94a
--- /dev/null
+++ b/qdp/qdp-python/tests/test_dlpack_validation.py
@@ -0,0 +1,100 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one or more
+# contributor license agreements. See the NOTICE file distributed with
+# this work for additional information regarding copyright ownership.
+# The ASF licenses this file to You under the Apache License, Version 2.0
+# (the "License"); you may not use this file except in compliance with
+# the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""Unit tests for DLPack validation in encode_from_pytorch
(extract_dlpack_tensor)."""
+
+import pytest
+
+try:
+ import torch
+ from _qdp import QdpEngine
+except ImportError:
+ torch = None
+ QdpEngine = None
+
+
+def _cuda_available():
+ return torch is not None and torch.cuda.is_available()
+
+
+def _engine():
+ if QdpEngine is None:
+ pytest.skip("_qdp not built")
+ e = QdpEngine(0)
+ return e
+
+
[email protected](not _cuda_available(), reason="CUDA not available")
+def test_dtype_validation_float32_rejected():
+ """DLPack tensor must be float64; float32 CUDA tensor should fail with
clear message."""
+ engine = _engine()
+ # 1D float32 CUDA tensor (contiguous)
+ t = torch.randn(4, dtype=torch.float32, device="cuda")
+ with pytest.raises(RuntimeError) as exc_info:
+ engine.encode(t, num_qubits=2, encoding_method="amplitude")
+ msg = str(exc_info.value).lower()
+ assert "float64" in msg
+ assert "code=" in msg or "bits=" in msg or "lanes=" in msg
+
+
[email protected](not _cuda_available(), reason="CUDA not available")
+def test_stride_1d_non_contiguous_rejected():
+ """Non-contiguous 1D CUDA tensor (stride != 1) should fail with actual vs
expected."""
+ engine = _engine()
+ # Slice so stride is 2: shape (2,), stride (2,)
+ t = torch.randn(4, dtype=torch.float64, device="cuda")[::2]
+ assert t.stride(0) != 1
+ with pytest.raises(RuntimeError) as exc_info:
+ engine.encode(t, num_qubits=1, encoding_method="amplitude")
+ msg = str(exc_info.value)
+ assert "contiguous" in msg.lower()
+ assert "stride[0]=" in msg
+ assert "expected 1" in msg or "expected 1 " in msg
+
+
[email protected](not _cuda_available(), reason="CUDA not available")
+def test_stride_2d_non_contiguous_rejected():
+    """Non-contiguous 2D CUDA tensor should fail with actual vs expected strides."""
+    engine = _engine()
+    # (4, 2) with strides (3, 2) -> not C-contiguous; expected for (4,2) is (2, 1)
+    t = torch.randn(4, 3, dtype=torch.float64, device="cuda")[:, ::2]
+    assert t.dim() == 2 and t.shape == (4, 2)
+    # Strides should be (3, 2) not (2, 1)
+    assert t.stride(0) == 3 and t.stride(1) == 2
+    with pytest.raises(RuntimeError) as exc_info:
+        engine.encode(t, num_qubits=1, encoding_method="amplitude")
+    msg = str(exc_info.value)
+    assert "contiguous" in msg.lower()
+    assert "strides=" in msg
+    assert "expected" in msg
+
+
[email protected](not _cuda_available(), reason="CUDA not available")
+def test_valid_cuda_float64_1d_succeeds():
+    """Valid 1D float64 contiguous CUDA tensor should encode successfully."""
+    engine = _engine()
+    t = torch.randn(4, dtype=torch.float64, device="cuda")
+    result = engine.encode(t, num_qubits=2, encoding_method="amplitude")
+    assert result is not None
+
+
[email protected](not _cuda_available(), reason="CUDA not available")
+def test_valid_cuda_float64_2d_succeeds():
+    """Valid 2D float64 contiguous CUDA tensor should encode successfully."""
+    engine = _engine()
+    t = torch.randn(3, 4, dtype=torch.float64, device="cuda")
+    result = engine.encode(t, num_qubits=2, encoding_method="amplitude")
+    assert result is not None