This is an automated email from the ASF dual-hosted git repository.

jiekaichang pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/mahout.git


The following commit(s) were added to refs/heads/main by this push:
     new d176ba52e [QDP] DLPack input extraction (#935)
d176ba52e is described below

commit d176ba52e3e987abf69765a298820ca436272af3
Author: Jie-Kai Chang <[email protected]>
AuthorDate: Sat Jan 31 21:14:58 2026 +0800

    [QDP] DLPack input extraction (#935)
    
    * DLPack input extraction
    
    Signed-off-by: 400Ping <[email protected]>
    
    * fix CI error
    
    Signed-off-by: 400Ping <[email protected]>
    
    * fix
    
    Signed-off-by: 400Ping <[email protected]>
    
    * update
    
    Signed-off-by: 400Ping <[email protected]>
    
    ---------
    
    Signed-off-by: 400Ping <[email protected]>
    Signed-off-by: 400Ping <[email protected]>
---
 qdp/qdp-core/src/dlpack.rs                     |   3 +
 qdp/qdp-python/pyproject.toml                  |   1 +
 qdp/qdp-python/src/lib.rs                      | 108 +++++++++++++++++++++----
 qdp/qdp-python/tests/test_dlpack_validation.py | 100 +++++++++++++++++++++++
 4 files changed, 197 insertions(+), 15 deletions(-)

diff --git a/qdp/qdp-core/src/dlpack.rs b/qdp/qdp-core/src/dlpack.rs
index eb3e596be..1780fa404 100644
--- a/qdp/qdp-core/src/dlpack.rs
+++ b/qdp/qdp-core/src/dlpack.rs
@@ -105,7 +105,10 @@ pub unsafe fn synchronize_stream(_stream: *mut c_void) -> Result<()> {
 
 // DLPack C structures (matching dlpack/dlpack.h)
 
+/// Device type enum for DLPack. Eq/PartialEq used for validation (e.g. device_type != kDLCUDA);
+/// Debug for diagnostics; Copy/Clone for FFI ergonomics when used in DLDevice.
 #[repr(C)]
+#[derive(Clone, Copy, Debug, Eq, PartialEq)]
 #[allow(non_camel_case_types)]
 pub enum DLDeviceType {
     kDLCPU = 1,
diff --git a/qdp/qdp-python/pyproject.toml b/qdp/qdp-python/pyproject.toml
index 909cd4a34..72f663a5a 100644
--- a/qdp/qdp-python/pyproject.toml
+++ b/qdp/qdp-python/pyproject.toml
@@ -15,6 +15,7 @@ classifiers = [
 dynamic = ["version"]
 
 [dependency-groups]
+dev = ["pytest", "torch>=2.2,<=2.9.0"]
 benchmark = [
     "numpy>=1.24,<2.0",
     "pandas>=2.0",
diff --git a/qdp/qdp-python/src/lib.rs b/qdp/qdp-python/src/lib.rs
index 1a6ebe735..fd655b2da 100644
--- a/qdp/qdp-python/src/lib.rs
+++ b/qdp/qdp-python/src/lib.rs
@@ -18,7 +18,7 @@ use numpy::{PyReadonlyArray1, PyReadonlyArray2, PyUntypedArrayMethods};
 use pyo3::exceptions::PyRuntimeError;
 use pyo3::ffi;
 use pyo3::prelude::*;
-use qdp_core::dlpack::DLManagedTensor;
+use qdp_core::dlpack::{DL_FLOAT, DLDeviceType, DLManagedTensor};
 use qdp_core::{Precision, QdpEngine as CoreEngine};
 use std::ffi::c_void;
 
@@ -409,46 +409,124 @@ fn extract_dlpack_tensor(_py: Python<'_>, tensor: &Bound<'_, PyAny>) -> PyResult
     // Note: PyTorch's __dlpack__ uses the default stream when called without arguments
     let capsule = tensor.call_method0("__dlpack__")?;
 
-    // Extract the DLManagedTensor pointer from the capsule
     const DLTENSOR_NAME: &[u8] = b"dltensor\0";
 
-    unsafe {
+    // SAFETY: capsule is a valid PyCapsule from tensor.__dlpack__(). DLTENSOR_NAME is a
+    // null-terminated C string for the lifetime of the call. We only read the capsule
+    // and call PyCapsule_IsValid / PyCapsule_GetPointer; we do not invalidate the capsule.
+    let managed_ptr = unsafe {
         let capsule_ptr = capsule.as_ptr();
-        let managed_ptr =
-            ffi::PyCapsule_GetPointer(capsule_ptr, DLTENSOR_NAME.as_ptr() as *const i8)
-                as *mut DLManagedTensor;
-
-        if managed_ptr.is_null() {
+        if ffi::PyCapsule_IsValid(capsule_ptr, DLTENSOR_NAME.as_ptr() as *const i8) == 0 {
+            return Err(PyRuntimeError::new_err(
+                "Invalid DLPack capsule (expected 'dltensor')",
+            ));
+        }
+        let ptr = ffi::PyCapsule_GetPointer(capsule_ptr, DLTENSOR_NAME.as_ptr() as *const i8)
+            as *mut DLManagedTensor;
+        if ptr.is_null() {
             return Err(PyRuntimeError::new_err(
                 "Failed to extract DLManagedTensor from PyCapsule",
             ));
         }
+        ptr
+    };
 
+    // SAFETY: managed_ptr is non-null and was returned by PyCapsule_GetPointer for a valid
+    // "dltensor" capsule, so it points to a valid DLManagedTensor. The capsule (and thus
+    // the tensor) is held by the caller for the duration of this function. We read fields
+    // and create slices from shape/strides only when non-null and ndim is valid.
+    unsafe {
         let dl_tensor = &(*managed_ptr).dl_tensor;
 
-        // Extract data pointer with null check
         if dl_tensor.data.is_null() {
             return Err(PyRuntimeError::new_err(
                 "DLPack tensor has null data pointer",
             ));
         }
-        let data_ptr = dl_tensor.data as *const f64;
 
-        // Extract shape
+        if dl_tensor.device.device_type != DLDeviceType::kDLCUDA {
+            return Err(PyRuntimeError::new_err(
+                "DLPack tensor must be on CUDA device",
+            ));
+        }
+
+        if dl_tensor.dtype.code != DL_FLOAT
+            || dl_tensor.dtype.bits != 64
+            || dl_tensor.dtype.lanes != 1
+        {
+            return Err(PyRuntimeError::new_err(format!(
+                "DLPack tensor must be float64 (code={}, bits={}, lanes={})",
+                dl_tensor.dtype.code, dl_tensor.dtype.bits, dl_tensor.dtype.lanes
+            )));
+        }
+
+        if !dl_tensor
+            .byte_offset
+            .is_multiple_of(std::mem::size_of::<f64>() as u64)
+        {
+            return Err(PyRuntimeError::new_err(
+                "DLPack tensor byte_offset is not aligned for float64",
+            ));
+        }
+
+        let data_ptr =
+            (dl_tensor.data as *const u8).add(dl_tensor.byte_offset as usize) as *const f64;
+
         let ndim = dl_tensor.ndim as usize;
+        // SAFETY: shape pointer is valid for ndim elements when non-null (DLPack contract).
         let shape = if ndim > 0 && !dl_tensor.shape.is_null() {
             std::slice::from_raw_parts(dl_tensor.shape, ndim).to_vec()
         } else {
             vec![]
         };
 
-        // Extract device_id
+        if ndim == 0 || shape.is_empty() {
+            return Err(PyRuntimeError::new_err(
+                "DLPack tensor must have at least 1 dimension",
+            ));
+        }
+
+        if !dl_tensor.strides.is_null() {
+            // SAFETY: strides pointer is valid for ndim elements (DLPack contract).
+            let strides = std::slice::from_raw_parts(dl_tensor.strides, ndim);
+            match ndim {
+                1 => {
+                    let expected = 1_i64;
+                    if strides[0] != expected {
+                        return Err(PyRuntimeError::new_err(format!(
+                            "DLPack tensor must be contiguous: stride[0]={}, expected {}",
+                            strides[0], expected
+                        )));
+                    }
+                }
+                2 => {
+                    if shape.len() < 2 {
+                        return Err(PyRuntimeError::new_err(
+                            "DLPack tensor must be contiguous (shape len < 2)",
+                        ));
+                    }
+                    let expected_stride_1 = 1_i64;
+                    let expected_stride_0 = shape[1];
+                    if strides[1] != expected_stride_1 || strides[0] != expected_stride_0 {
+                        return Err(PyRuntimeError::new_err(format!(
+                            "DLPack tensor must be contiguous: strides=[{}, {}], expected [{}, {}] (expected[1]=shape[1])",
+                            strides[0], strides[1], expected_stride_0, expected_stride_1
+                        )));
+                    }
+                }
+                _ => {
+                    return Err(PyRuntimeError::new_err(
+                        "DLPack tensor must be 1D or 2D for encoding",
+                    ));
+                }
+            }
+        }
+
         let device_id = dl_tensor.device.device_id;
 
-        // Rename the capsule to "used_dltensor" as per DLPack protocol
-        // This prevents PyTorch from trying to delete it when the capsule is garbage collected
         const USED_DLTENSOR_NAME: &[u8] = b"used_dltensor\0";
-        ffi::PyCapsule_SetName(capsule_ptr, USED_DLTENSOR_NAME.as_ptr() as *const i8);
+        // SAFETY: capsule is the same PyCapsule we used above; renaming is allowed and does not free it.
+        ffi::PyCapsule_SetName(capsule.as_ptr(), USED_DLTENSOR_NAME.as_ptr() as *const i8);
 
         Ok(DLPackTensorInfo {
             managed_ptr,
diff --git a/qdp/qdp-python/tests/test_dlpack_validation.py b/qdp/qdp-python/tests/test_dlpack_validation.py
new file mode 100644
index 000000000..84335c94a
--- /dev/null
+++ b/qdp/qdp-python/tests/test_dlpack_validation.py
@@ -0,0 +1,100 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one or more
+# contributor license agreements.  See the NOTICE file distributed with
+# this work for additional information regarding copyright ownership.
+# The ASF licenses this file to You under the Apache License, Version 2.0
+# (the "License"); you may not use this file except in compliance with
+# the License.  You may obtain a copy of the License at
+#
+#    http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""Unit tests for DLPack validation in encode_from_pytorch (extract_dlpack_tensor)."""
+
+import pytest
+
+try:
+    import torch
+    from _qdp import QdpEngine
+except ImportError:
+    torch = None
+    QdpEngine = None
+
+
+def _cuda_available():
+    return torch is not None and torch.cuda.is_available()
+
+
+def _engine():
+    if QdpEngine is None:
+        pytest.skip("_qdp not built")
+    e = QdpEngine(0)
+    return e
+
+
+@pytest.mark.skipif(not _cuda_available(), reason="CUDA not available")
+def test_dtype_validation_float32_rejected():
+    """DLPack tensor must be float64; float32 CUDA tensor should fail with clear message."""
+    engine = _engine()
+    # 1D float32 CUDA tensor (contiguous)
+    t = torch.randn(4, dtype=torch.float32, device="cuda")
+    with pytest.raises(RuntimeError) as exc_info:
+        engine.encode(t, num_qubits=2, encoding_method="amplitude")
+    msg = str(exc_info.value).lower()
+    assert "float64" in msg
+    assert "code=" in msg or "bits=" in msg or "lanes=" in msg
+
+
+@pytest.mark.skipif(not _cuda_available(), reason="CUDA not available")
+def test_stride_1d_non_contiguous_rejected():
+    """Non-contiguous 1D CUDA tensor (stride != 1) should fail with actual vs expected."""
+    engine = _engine()
+    # Slice so stride is 2: shape (2,), stride (2,)
+    t = torch.randn(4, dtype=torch.float64, device="cuda")[::2]
+    assert t.stride(0) != 1
+    with pytest.raises(RuntimeError) as exc_info:
+        engine.encode(t, num_qubits=1, encoding_method="amplitude")
+    msg = str(exc_info.value)
+    assert "contiguous" in msg.lower()
+    assert "stride[0]=" in msg
+    assert "expected 1" in msg or "expected 1 " in msg
+
+
+@pytest.mark.skipif(not _cuda_available(), reason="CUDA not available")
+def test_stride_2d_non_contiguous_rejected():
+    """Non-contiguous 2D CUDA tensor should fail with actual vs expected strides."""
+    engine = _engine()
+    # (4, 2) with strides (3, 2) -> not C-contiguous; expected for (4,2) is (2, 1)
+    t = torch.randn(4, 3, dtype=torch.float64, device="cuda")[:, ::2]
+    assert t.dim() == 2 and t.shape == (4, 2)
+    # Strides should be (3, 2) not (2, 1)
+    assert t.stride(0) == 3 and t.stride(1) == 2
+    with pytest.raises(RuntimeError) as exc_info:
+        engine.encode(t, num_qubits=1, encoding_method="amplitude")
+    msg = str(exc_info.value)
+    assert "contiguous" in msg.lower()
+    assert "strides=" in msg
+    assert "expected" in msg
+
+
+@pytest.mark.skipif(not _cuda_available(), reason="CUDA not available")
+def test_valid_cuda_float64_1d_succeeds():
+    """Valid 1D float64 contiguous CUDA tensor should encode successfully."""
+    engine = _engine()
+    t = torch.randn(4, dtype=torch.float64, device="cuda")
+    result = engine.encode(t, num_qubits=2, encoding_method="amplitude")
+    assert result is not None
+
+
+@pytest.mark.skipif(not _cuda_available(), reason="CUDA not available")
+def test_valid_cuda_float64_2d_succeeds():
+    """Valid 2D float64 contiguous CUDA tensor should encode successfully."""
+    engine = _engine()
+    t = torch.randn(3, 4, dtype=torch.float64, device="cuda")
+    result = engine.encode(t, num_qubits=2, encoding_method="amplitude")
+    assert result is not None

Reply via email to