This is an automated email from the ASF dual-hosted git repository.

guanmingchiu pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/mahout.git


The following commit(s) were added to refs/heads/main by this push:
     new 34e4a8bea [QDP] PyTorch GPU Pointer Validation (#929)
34e4a8bea is described below

commit 34e4a8bea58e7c61cbdb8f727caf3ccb7fc99c06
Author: Jie-Kai Chang <[email protected]>
AuthorDate: Fri Jan 30 13:57:35 2026 +0800

    [QDP] PyTorch GPU Pointer Validation (#929)
    
    * GPU pointer validation
    
    Signed-off-by: 400Ping <[email protected]>
    
    * update
    
    Signed-off-by: 400Ping <[email protected]>
    
    * fix conflicts
    
    Signed-off-by: 400Ping <[email protected]>
    
    * fix ci error
    
    Signed-off-by: 400Ping <[email protected]>
    
    * add unit test
    
    Signed-off-by: 400Ping <[email protected]>
    
    * fix pre-commit
    
    Signed-off-by: 400Ping <[email protected]>
    
    ---------
    
    Signed-off-by: 400Ping <[email protected]>
    Signed-off-by: 400Ping <[email protected]>
---
 qdp/qdp-core/src/dlpack.rs       | 83 ++++++++++++++++++++++++++++++++++++++++
 qdp/qdp-core/src/gpu/cuda_ffi.rs | 17 ++++++++
 qdp/qdp-core/src/lib.rs          | 55 ++++++++++++++++++++++++++
 qdp/qdp-core/tests/dlpack.rs     | 28 ++++++++++++++
 qdp/qdp-python/src/lib.rs        | 14 ++++++-
 testing/qdp/test_bindings.py     | 19 +++++++++
 6 files changed, 214 insertions(+), 2 deletions(-)

diff --git a/qdp/qdp-core/src/dlpack.rs b/qdp/qdp-core/src/dlpack.rs
index 4d3ac764d..eb3e596be 100644
--- a/qdp/qdp-core/src/dlpack.rs
+++ b/qdp/qdp-core/src/dlpack.rs
@@ -16,10 +16,93 @@
 
 // DLPack protocol for zero-copy GPU memory sharing with PyTorch
 
+use crate::error::Result;
+#[cfg(target_os = "linux")]
+use crate::error::{MahoutError, cuda_error_to_string};
 use crate::gpu::memory::{BufferStorage, GpuStateVector, Precision};
 use std::os::raw::{c_int, c_void};
 use std::sync::Arc;
 
+#[cfg(target_os = "linux")]
+use crate::gpu::cuda_ffi::{
+    CUDA_EVENT_DISABLE_TIMING, cudaEventCreateWithFlags, cudaEventDestroy, 
cudaEventRecord,
+    cudaStreamWaitEvent,
+};
+
+/// DLPack CUDA stream sentinel for the legacy default stream.
+///
+/// Matches `cudaStreamLegacy` (value `1`) in the CUDA runtime. The sentinels are
+/// plain integer handles rather than addresses of any allocation, so they are
+/// built with `ptr::without_provenance_mut`, the std constructor for
+/// address-only pointers.
+pub const CUDA_STREAM_LEGACY: *mut c_void = std::ptr::without_provenance_mut(1);
+/// DLPack CUDA stream sentinel for the per-thread default stream.
+///
+/// Matches `cudaStreamPerThread` (value `2`) in the CUDA runtime.
+pub const CUDA_STREAM_PER_THREAD: *mut c_void = std::ptr::without_provenance_mut(2);
+
+/// Map a DLPack `stream` argument to a CUDA stream handle.
+///
+/// Follows the DLPack `__dlpack__(stream=...)` convention for CUDA:
+/// - `1` is the legacy default stream.
+/// - `2` is the per-thread default stream.
+/// - Any larger value is a raw `cudaStream_t` address.
+///
+/// `-1` means "do not synchronize" and `0` is disallowed by the spec. All
+/// non-positive values therefore map to a null pointer, which
+/// `synchronize_stream` treats as a no-op, instead of being reinterpreted as a
+/// bogus stream address.
+pub fn dlpack_stream_to_cuda(stream: i64) -> *mut c_void {
+    match stream {
+        i64::MIN..=0 => std::ptr::null_mut(),
+        1 => CUDA_STREAM_LEGACY,
+        2 => CUDA_STREAM_PER_THREAD,
+        _ => stream as *mut c_void,
+    }
+}
+
+#[cfg(target_os = "linux")]
+/// Make `stream` wait, on the device, for work previously enqueued on the
+/// legacy default stream. The host is not blocked.
+///
+/// A timing-disabled CUDA event is recorded on the null (legacy default) stream,
+/// and then `cudaStreamWaitEvent` is enqueued on `stream`. Once created, the
+/// event is destroyed before returning on every path. A null `stream` is a
+/// no-op.
+///
+/// NOTE(review): this only orders `stream` after producer work issued on the
+/// legacy default stream (or on streams that implicitly synchronize with it).
+/// Confirm the encoder never launches on a non-blocking stream.
+///
+/// # Errors
+/// Returns `MahoutError::Cuda` if creating, recording, waiting on, or
+/// destroying the event fails. A record/wait failure takes precedence over a
+/// subsequent destroy failure.
+///
+/// # Safety
+/// `stream` must be a valid CUDA stream pointer or one of the CUDA sentinel
+/// values (legacy/per-thread default). Passing any other pointer is undefined.
+pub unsafe fn synchronize_stream(stream: *mut c_void) -> Result<()> {
+    /// Convert a CUDA runtime status code into `Ok(())` or a `MahoutError::Cuda`
+    /// naming the call that produced it.
+    fn check(code: i32, call: &str) -> Result<()> {
+        if code == 0 {
+            Ok(())
+        } else {
+            Err(MahoutError::Cuda(format!(
+                "{} failed: {} ({})",
+                call,
+                code,
+                cuda_error_to_string(code)
+            )))
+        }
+    }
+
+    if stream.is_null() {
+        return Ok(());
+    }
+
+    let mut event: *mut c_void = std::ptr::null_mut();
+    // SAFETY: `event` is a valid out-pointer for the newly created event handle.
+    check(
+        unsafe { cudaEventCreateWithFlags(&mut event, CUDA_EVENT_DISABLE_TIMING) },
+        "cudaEventCreateWithFlags",
+    )?;
+
+    // SAFETY: `event` was created above; the null stream handle denotes the legacy
+    // default stream, and the caller guarantees `stream` is a valid stream/sentinel.
+    let enqueued = check(
+        unsafe { cudaEventRecord(event, std::ptr::null_mut()) },
+        "cudaEventRecord",
+    )
+    .and_then(|()| {
+        check(
+            unsafe { cudaStreamWaitEvent(stream, event, 0) },
+            "cudaStreamWaitEvent",
+        )
+    });
+
+    // SAFETY: `event` is a live handle created above and is not used afterwards.
+    // CUDA releases a recorded-but-incomplete event only once the device reaches it,
+    // so destroying it here does not cancel the wait enqueued on `stream`.
+    let destroyed = check(unsafe { cudaEventDestroy(event) }, "cudaEventDestroy");
+
+    // Report a record/wait failure first; otherwise surface any destroy failure.
+    enqueued.and(destroyed)
+}
+
+#[cfg(not(target_os = "linux"))]
+/// Stand-in for the Linux implementation. No CUDA stream synchronization is
+/// performed on this target, so every input, including null, yields `Ok(())`.
+///
+/// # Safety
+/// Nothing is dereferenced here. The function is `unsafe` only so that its
+/// signature matches the Linux build.
+pub unsafe fn synchronize_stream(_stream: *mut c_void) -> Result<()> {
+    // Intentionally a no-op: the stream argument is ignored.
+    Ok(())
+}
+
 // DLPack C structures (matching dlpack/dlpack.h)
 
 #[repr(C)]
diff --git a/qdp/qdp-core/src/gpu/cuda_ffi.rs b/qdp/qdp-core/src/gpu/cuda_ffi.rs
index 6d8e4cb74..491e1382b 100644
--- a/qdp/qdp-core/src/gpu/cuda_ffi.rs
+++ b/qdp/qdp-core/src/gpu/cuda_ffi.rs
@@ -21,6 +21,18 @@ use std::ffi::c_void;
 pub(crate) const CUDA_MEMCPY_HOST_TO_DEVICE: u32 = 1;
 pub(crate) const CUDA_EVENT_DISABLE_TIMING: u32 = 0x02;
 pub(crate) const CUDA_EVENT_DEFAULT: u32 = 0x00;
+/// `cudaMemoryTypeDevice` value of the CUDA runtime `cudaMemoryType` enum.
+pub(crate) const CUDA_MEMORY_TYPE_DEVICE: i32 = 2;
+/// `cudaMemoryTypeManaged` value of the CUDA runtime `cudaMemoryType` enum.
+pub(crate) const CUDA_MEMORY_TYPE_MANAGED: i32 = 3;
+
+/// Out-parameter filled by `cudaPointerGetAttributes`.
+///
+/// Field order is the FFI contract; do not reorder or remove fields. The first
+/// four fields are intended to mirror the CUDA runtime `cudaPointerAttributes`
+/// layout (`type`, `device`, `devicePointer`, `hostPointer`).
+///
+/// NOTE(review): CUDA 11+ headers appear to define no members after
+/// `hostPointer`. If so, `is_managed` and `allocation_flags` are never written
+/// by the runtime and act only as spare trailing space. Confirm against the CUDA
+/// headers in use before relying on either field.
+#[repr(C)]
+pub(crate) struct CudaPointerAttributes {
+    /// Memory kind; compared against the `CUDA_MEMORY_TYPE_*` constants.
+    pub memory_type: i32,
+    /// Device ordinal reported for the allocation.
+    pub device: i32,
+    pub device_pointer: *mut c_void,
+    pub host_pointer: *mut c_void,
+    pub is_managed: i32,
+    pub allocation_flags: u32,
+}
 
 // CUDA error codes
 pub(crate) const CUDA_SUCCESS: i32 = 0;
@@ -33,6 +45,11 @@ unsafe extern "C" {
     pub(crate) fn cudaHostAlloc(pHost: *mut *mut c_void, size: usize, flags: 
u32) -> i32;
     pub(crate) fn cudaFreeHost(ptr: *mut c_void) -> i32;
 
+    pub(crate) fn cudaPointerGetAttributes(
+        attributes: *mut CudaPointerAttributes,
+        ptr: *const c_void,
+    ) -> i32;
+
     pub(crate) fn cudaMemGetInfo(free: *mut usize, total: *mut usize) -> i32;
 
     pub(crate) fn cudaMemcpyAsync(
diff --git a/qdp/qdp-core/src/lib.rs b/qdp/qdp-core/src/lib.rs
index f6b58d50a..9a5290447 100644
--- a/qdp/qdp-core/src/lib.rs
+++ b/qdp/qdp-core/src/lib.rs
@@ -41,6 +41,57 @@ use crate::dlpack::DLManagedTensor;
 use crate::gpu::get_encoder;
 use cudarc::driver::CudaDevice;
 
+#[cfg(target_os = "linux")]
+/// Reject input pointers that are not usable as CUDA device memory on `device`.
+///
+/// The pointer is accepted only if all of the following hold:
+/// - it is non-null;
+/// - `cudaPointerGetAttributes` classifies it as device or managed memory;
+/// - when CUDA reports a non-negative device ordinal, that ordinal matches the
+///   engine's device.
+///
+/// # Errors
+/// Every rejection is reported as `MahoutError::InvalidInput`, including a
+/// failing `cudaPointerGetAttributes` call.
+fn validate_cuda_input_ptr(device: &CudaDevice, ptr: *const f64) -> Result<()> {
+    use crate::gpu::cuda_ffi::{
+        CUDA_MEMORY_TYPE_DEVICE, CUDA_MEMORY_TYPE_MANAGED, CudaPointerAttributes,
+        cudaPointerGetAttributes,
+    };
+    use std::ffi::c_void;
+
+    if ptr.is_null() {
+        return Err(MahoutError::InvalidInput(
+            "Input GPU pointer is null".to_string(),
+        ));
+    }
+
+    // Zero-initialised out-parameter; CUDA overwrites it on success.
+    let mut info = CudaPointerAttributes {
+        memory_type: 0,
+        device: 0,
+        device_pointer: std::ptr::null_mut(),
+        host_pointer: std::ptr::null_mut(),
+        is_managed: 0,
+        allocation_flags: 0,
+    };
+
+    // SAFETY: `info` is a live, writable struct for the duration of the call, and
+    // the call only queries attributes of the address held in `ptr`.
+    let status = unsafe { cudaPointerGetAttributes(&mut info, ptr.cast::<c_void>()) };
+    if status != 0 {
+        return Err(MahoutError::InvalidInput(format!(
+            "cudaPointerGetAttributes failed for input pointer: {} ({})",
+            status,
+            cuda_error_to_string(status)
+        )));
+    }
+
+    // Only device and managed allocations are accepted.
+    if !matches!(
+        info.memory_type,
+        CUDA_MEMORY_TYPE_DEVICE | CUDA_MEMORY_TYPE_MANAGED
+    ) {
+        return Err(MahoutError::InvalidInput(format!(
+            "Input pointer is not CUDA device memory (memory_type={})",
+            info.memory_type
+        )));
+    }
+
+    // A negative reported ordinal is not treated as a mismatch.
+    let engine_ordinal = device.ordinal() as i32;
+    let pointer_ordinal = info.device;
+    let on_other_device = pointer_ordinal >= 0 && pointer_ordinal != engine_ordinal;
+    if on_other_device {
+        return Err(MahoutError::InvalidInput(format!(
+            "Input pointer device mismatch: pointer on cuda:{}, engine on cuda:{}",
+            pointer_ordinal, engine_ordinal
+        )));
+    }
+
+    Ok(())
+}
+
 /// Main entry point for Mahout QDP
 ///
 /// Manages GPU context and dispatches encoding tasks.
@@ -345,6 +396,8 @@ impl QdpEngine {
             ));
         }
 
+        validate_cuda_input_ptr(&self.device, input_d)?;
+
         let state_len = 1usize << num_qubits;
         let method = encoding_method.to_ascii_lowercase();
 
@@ -513,6 +566,8 @@ impl QdpEngine {
             ));
         }
 
+        validate_cuda_input_ptr(&self.device, input_batch_d)?;
+
         let state_len = 1usize << num_qubits;
         let method = encoding_method.to_ascii_lowercase();
 
diff --git a/qdp/qdp-core/tests/dlpack.rs b/qdp/qdp-core/tests/dlpack.rs
index f42474a65..6b97283ce 100644
--- a/qdp/qdp-core/tests/dlpack.rs
+++ b/qdp/qdp-core/tests/dlpack.rs
@@ -18,7 +18,10 @@
 
 #[cfg(test)]
 mod dlpack_tests {
+    use std::ffi::c_void;
+
     use cudarc::driver::CudaDevice;
+    use qdp_core::dlpack::{CUDA_STREAM_LEGACY, synchronize_stream};
     use qdp_core::gpu::memory::GpuStateVector;
 
     #[test]
@@ -82,4 +85,29 @@ mod dlpack_tests {
             }
         }
     }
+
+    /// synchronize_stream(null) short-circuits to Ok(()) on every platform,
+    /// before any CUDA call is made.
+    #[test]
+    fn test_synchronize_stream_null() {
+        let null_stream = std::ptr::null_mut::<c_void>();
+        // SAFETY: a null stream is handled as an early return in every build.
+        let result = unsafe { synchronize_stream(null_stream) };
+        assert!(
+            result.is_ok(),
+            "synchronize_stream(null) should return Ok(())"
+        );
+    }
+
+    /// synchronize_stream(CUDA_STREAM_LEGACY) enqueues a wait on the legacy
+    /// default stream (Linux + CUDA). Skipped when no CUDA device is available.
+    #[test]
+    #[cfg(target_os = "linux")]
+    fn test_synchronize_stream_legacy() {
+        // CPU-only hosts (e.g. CI runners) would fail the CUDA event calls, so
+        // bail out early. The device handle is kept alive for the whole test.
+        let Ok(_device) = CudaDevice::new(0) else {
+            println!("Skipping test_synchronize_stream_legacy: no CUDA device available");
+            return;
+        };
+
+        // SAFETY: CUDA_STREAM_LEGACY is one of the sentinels documented as valid.
+        let result = unsafe { synchronize_stream(CUDA_STREAM_LEGACY) };
+        assert!(
+            result.is_ok(),
+            "synchronize_stream(CUDA_STREAM_LEGACY) should succeed on Linux with CUDA"
+        );
+    }
 }
diff --git a/qdp/qdp-python/src/lib.rs b/qdp/qdp-python/src/lib.rs
index 8647d98b7..1af58a617 100644
--- a/qdp/qdp-python/src/lib.rs
+++ b/qdp/qdp-python/src/lib.rs
@@ -45,7 +45,7 @@ impl QuantumTensor {
     /// The capsule can only be consumed once to prevent double-free errors.
     ///
     /// Args:
-    ///     stream: Optional CUDA stream pointer (for DLPack 0.8+)
+    ///     stream: Optional CUDA stream (DLPack 0.8+; 1=legacy default, 
2=per-thread default)
     ///
     /// Returns:
     ///     PyCapsule containing DLManagedTensor pointer
@@ -54,7 +54,6 @@ impl QuantumTensor {
     ///     RuntimeError: If the tensor has already been consumed
     #[pyo3(signature = (stream=None))]
     fn __dlpack__<'py>(&mut self, py: Python<'py>, stream: Option<i64>) -> 
PyResult<Py<PyAny>> {
-        let _ = stream; // Suppress unused variable warning
         if self.consumed {
             return Err(PyRuntimeError::new_err(
                 "DLPack tensor already consumed (can only be used once)",
@@ -65,6 +64,17 @@ impl QuantumTensor {
             return Err(PyRuntimeError::new_err("Invalid DLPack tensor 
pointer"));
         }
 
+        if let Some(stream) = stream
+            && stream > 0
+        {
+            let stream_ptr = qdp_core::dlpack::dlpack_stream_to_cuda(stream);
+            unsafe {
+                qdp_core::dlpack::synchronize_stream(stream_ptr).map_err(|e| {
+                    PyRuntimeError::new_err(format!("CUDA stream sync failed: 
{}", e))
+                })?;
+            }
+        }
+
         // Mark as consumed to prevent double-free
         self.consumed = true;
 
diff --git a/testing/qdp/test_bindings.py b/testing/qdp/test_bindings.py
index 58f5d5af6..d213d55cd 100644
--- a/testing/qdp/test_bindings.py
+++ b/testing/qdp/test_bindings.py
@@ -125,6 +125,25 @@ def test_dlpack_single_use():
         qtensor2.__dlpack__()
 
 
+@requires_qdp
+@pytest.mark.gpu
+@pytest.mark.parametrize("stream", [1, 2], ids=["stream_legacy", "stream_per_thread"])
+def test_dlpack_with_stream(stream):
+    """Exercise __dlpack__(stream=...) with the DLPack 0.8+ default-stream sentinels.
+
+    stream=1 is the legacy default stream and stream=2 the per-thread default
+    stream. Both must synchronize and still hand back a consumable capsule.
+    """
+    import torch
+    from _qdp import QdpEngine
+
+    amplitudes = [1.0, 2.0, 3.0, 4.0]
+    qtensor = QdpEngine(0).encode(amplitudes, 2, "amplitude")
+
+    torch_tensor = torch.from_dlpack(qtensor.__dlpack__(stream=stream))
+
+    assert torch_tensor.is_cuda
+    assert torch_tensor.shape == (1, 4)
+
 @requires_qdp
 @pytest.mark.gpu
 def test_pytorch_integration():

Reply via email to