This is an automated email from the ASF dual-hosted git repository.
guanmingchiu pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/mahout.git
The following commit(s) were added to refs/heads/main by this push:
new 34e4a8bea [QDP] PyTorch GPU Pointer Validation (#929)
34e4a8bea is described below
commit 34e4a8bea58e7c61cbdb8f727caf3ccb7fc99c06
Author: Jie-Kai Chang <[email protected]>
AuthorDate: Fri Jan 30 13:57:35 2026 +0800
[QDP] PyTorch GPU Pointer Validation (#929)
* GPU pointer validation
Signed-off-by: 400Ping <[email protected]>
* update
Signed-off-by: 400Ping <[email protected]>
* fix conflicts
Signed-off-by: 400Ping <[email protected]>
* fix ci error
Signed-off-by: 400Ping <[email protected]>
* add unit test
Signed-off-by: 400Ping <[email protected]>
* fix pre-commit
Signed-off-by: 400Ping <[email protected]>
---------
Signed-off-by: 400Ping <[email protected]>
Signed-off-by: 400Ping <[email protected]>
---
qdp/qdp-core/src/dlpack.rs | 83 ++++++++++++++++++++++++++++++++++++++++
qdp/qdp-core/src/gpu/cuda_ffi.rs | 17 ++++++++
qdp/qdp-core/src/lib.rs | 55 ++++++++++++++++++++++++++
qdp/qdp-core/tests/dlpack.rs | 28 ++++++++++++++
qdp/qdp-python/src/lib.rs | 14 ++++++-
testing/qdp/test_bindings.py | 19 +++++++++
6 files changed, 214 insertions(+), 2 deletions(-)
diff --git a/qdp/qdp-core/src/dlpack.rs b/qdp/qdp-core/src/dlpack.rs
index 4d3ac764d..eb3e596be 100644
--- a/qdp/qdp-core/src/dlpack.rs
+++ b/qdp/qdp-core/src/dlpack.rs
@@ -16,10 +16,93 @@
// DLPack protocol for zero-copy GPU memory sharing with PyTorch
+use crate::error::Result;
+#[cfg(target_os = "linux")]
+use crate::error::{MahoutError, cuda_error_to_string};
use crate::gpu::memory::{BufferStorage, GpuStateVector, Precision};
use std::os::raw::{c_int, c_void};
use std::sync::Arc;
+#[cfg(target_os = "linux")]
+use crate::gpu::cuda_ffi::{
+    CUDA_EVENT_DISABLE_TIMING, cudaEventCreateWithFlags, cudaEventDestroy, cudaEventRecord,
+ cudaStreamWaitEvent,
+};
+
+/// DLPack CUDA stream sentinel values (legacy/per-thread default).
+/// These match cudaStreamLegacy/cudaStreamPerThread in the CUDA runtime.
+#[allow(clippy::manual_dangling_ptr)]
+pub const CUDA_STREAM_LEGACY: *mut c_void = 1 as *mut c_void;
+#[allow(clippy::manual_dangling_ptr)]
+pub const CUDA_STREAM_PER_THREAD: *mut c_void = 2 as *mut c_void;
+
+/// Map DLPack stream integer to a CUDA stream pointer.
+pub fn dlpack_stream_to_cuda(stream: i64) -> *mut c_void {
+ match stream {
+ 1 => CUDA_STREAM_LEGACY,
+ 2 => CUDA_STREAM_PER_THREAD,
+ _ => stream as *mut c_void,
+ }
+}
+
+#[cfg(target_os = "linux")]
+/// # Safety
+/// `stream` must be a valid CUDA stream pointer or one of the CUDA sentinel
+/// values (legacy/per-thread default). Passing any other pointer is undefined.
+pub unsafe fn synchronize_stream(stream: *mut c_void) -> Result<()> {
+ if stream.is_null() {
+ return Ok(());
+ }
+
+ let mut event: *mut c_void = std::ptr::null_mut();
+    let ret = unsafe { cudaEventCreateWithFlags(&mut event, CUDA_EVENT_DISABLE_TIMING) };
+ if ret != 0 {
+ return Err(MahoutError::Cuda(format!(
+ "cudaEventCreateWithFlags failed: {} ({})",
+ ret,
+ cuda_error_to_string(ret)
+ )));
+ }
+
+ let record_ret = unsafe { cudaEventRecord(event, std::ptr::null_mut()) };
+ if record_ret != 0 {
+ let _ = unsafe { cudaEventDestroy(event) };
+ return Err(MahoutError::Cuda(format!(
+ "cudaEventRecord failed: {} ({})",
+ record_ret,
+ cuda_error_to_string(record_ret)
+ )));
+ }
+
+ let wait_ret = unsafe { cudaStreamWaitEvent(stream, event, 0) };
+ if wait_ret != 0 {
+ let _ = unsafe { cudaEventDestroy(event) };
+ return Err(MahoutError::Cuda(format!(
+ "cudaStreamWaitEvent failed: {} ({})",
+ wait_ret,
+ cuda_error_to_string(wait_ret)
+ )));
+ }
+
+ let destroy_ret = unsafe { cudaEventDestroy(event) };
+ if destroy_ret != 0 {
+ return Err(MahoutError::Cuda(format!(
+ "cudaEventDestroy failed: {} ({})",
+ destroy_ret,
+ cuda_error_to_string(destroy_ret)
+ )));
+ }
+
+ Ok(())
+}
+
+#[cfg(not(target_os = "linux"))]
+/// # Safety
+/// No-op on non-Linux targets, kept unsafe to match the Linux signature.
+pub unsafe fn synchronize_stream(_stream: *mut c_void) -> Result<()> {
+ Ok(())
+}
+
// DLPack C structures (matching dlpack/dlpack.h)
#[repr(C)]
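For reference, a minimal sketch of how a capsule exporter could drive the two public helpers above when a DLPack consumer passes a `stream` argument (illustrative only, not part of this patch; assumes qdp_core as a dependency and an available CUDA runtime). `synchronize_stream` records an event on the default stream and makes the consumer's stream wait on it, so work queued before the export is ordered ahead of the consumer's reads:

    use qdp_core::dlpack::{dlpack_stream_to_cuda, synchronize_stream};

    // Hypothetical helper, shown only to illustrate the intended call order.
    fn sync_for_consumer(stream: i64) {
        // DLPack passes 1 for the legacy default stream and 2 for the
        // per-thread default stream; other values are raw cudaStream_t
        // handles cast to an integer.
        let stream_ptr = dlpack_stream_to_cuda(stream);
        // SAFETY: the consumer (e.g. PyTorch) supplies either a sentinel
        // value or a valid stream handle, as synchronize_stream requires.
        if let Err(e) = unsafe { synchronize_stream(stream_ptr) } {
            eprintln!("CUDA stream sync failed: {}", e);
        }
    }

The qdp-python `__dlpack__` change further down performs the same sequence, mapping the error to a Python RuntimeError.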
diff --git a/qdp/qdp-core/src/gpu/cuda_ffi.rs b/qdp/qdp-core/src/gpu/cuda_ffi.rs
index 6d8e4cb74..491e1382b 100644
--- a/qdp/qdp-core/src/gpu/cuda_ffi.rs
+++ b/qdp/qdp-core/src/gpu/cuda_ffi.rs
@@ -21,6 +21,18 @@ use std::ffi::c_void;
pub(crate) const CUDA_MEMCPY_HOST_TO_DEVICE: u32 = 1;
pub(crate) const CUDA_EVENT_DISABLE_TIMING: u32 = 0x02;
pub(crate) const CUDA_EVENT_DEFAULT: u32 = 0x00;
+pub(crate) const CUDA_MEMORY_TYPE_DEVICE: i32 = 2;
+pub(crate) const CUDA_MEMORY_TYPE_MANAGED: i32 = 3;
+
+#[repr(C)]
+pub(crate) struct CudaPointerAttributes {
+ pub memory_type: i32,
+ pub device: i32,
+ pub device_pointer: *mut c_void,
+ pub host_pointer: *mut c_void,
+ pub is_managed: i32,
+ pub allocation_flags: u32,
+}
// CUDA error codes
pub(crate) const CUDA_SUCCESS: i32 = 0;
@@ -33,6 +45,11 @@ unsafe extern "C" {
    pub(crate) fn cudaHostAlloc(pHost: *mut *mut c_void, size: usize, flags: u32) -> i32;
pub(crate) fn cudaFreeHost(ptr: *mut c_void) -> i32;
+ pub(crate) fn cudaPointerGetAttributes(
+ attributes: *mut CudaPointerAttributes,
+ ptr: *const c_void,
+ ) -> i32;
+
pub(crate) fn cudaMemGetInfo(free: *mut usize, total: *mut usize) -> i32;
pub(crate) fn cudaMemcpyAsync(
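For context, the new constants follow the CUDA runtime's cudaMemoryType enum (0 = unregistered, 1 = host, 2 = device, 3 = managed). A small standalone sketch (not part of this patch, assuming the standard enum values) of how a filled memory_type field reads; this is also the value printed by the validation error added in qdp-core/src/lib.rs:

    // Illustrative mapping of cudaMemoryType values; only 2 and 3 are
    // accepted by the new pointer validation.
    fn describe_memory_type(memory_type: i32) -> &'static str {
        match memory_type {
            0 => "unregistered host memory (cudaMemoryTypeUnregistered)",
            1 => "registered/pinned host memory (cudaMemoryTypeHost)",
            2 => "device memory (cudaMemoryTypeDevice)",
            3 => "managed/unified memory (cudaMemoryTypeManaged)",
            _ => "unknown memory type",
        }
    }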
diff --git a/qdp/qdp-core/src/lib.rs b/qdp/qdp-core/src/lib.rs
index f6b58d50a..9a5290447 100644
--- a/qdp/qdp-core/src/lib.rs
+++ b/qdp/qdp-core/src/lib.rs
@@ -41,6 +41,57 @@ use crate::dlpack::DLManagedTensor;
use crate::gpu::get_encoder;
use cudarc::driver::CudaDevice;
+#[cfg(target_os = "linux")]
+fn validate_cuda_input_ptr(device: &CudaDevice, ptr: *const f64) -> Result<()> {
+ use crate::gpu::cuda_ffi::{
+        CUDA_MEMORY_TYPE_DEVICE, CUDA_MEMORY_TYPE_MANAGED, CudaPointerAttributes,
+ cudaPointerGetAttributes,
+ };
+ use std::ffi::c_void;
+
+ if ptr.is_null() {
+ return Err(MahoutError::InvalidInput(
+ "Input GPU pointer is null".to_string(),
+ ));
+ }
+
+ let mut attrs = CudaPointerAttributes {
+ memory_type: 0,
+ device: 0,
+ device_pointer: std::ptr::null_mut(),
+ host_pointer: std::ptr::null_mut(),
+ is_managed: 0,
+ allocation_flags: 0,
+ };
+
+    let ret = unsafe { cudaPointerGetAttributes(&mut attrs as *mut _, ptr as *const c_void) };
+ if ret != 0 {
+ return Err(MahoutError::InvalidInput(format!(
+ "cudaPointerGetAttributes failed for input pointer: {} ({})",
+ ret,
+ cuda_error_to_string(ret)
+ )));
+ }
+
+    if attrs.memory_type != CUDA_MEMORY_TYPE_DEVICE && attrs.memory_type != CUDA_MEMORY_TYPE_MANAGED
+ {
+ return Err(MahoutError::InvalidInput(format!(
+ "Input pointer is not CUDA device memory (memory_type={})",
+ attrs.memory_type
+ )));
+ }
+
+ let device_ordinal = device.ordinal() as i32;
+ if attrs.device >= 0 && attrs.device != device_ordinal {
+ return Err(MahoutError::InvalidInput(format!(
+ "Input pointer device mismatch: pointer on cuda:{}, engine on
cuda:{}",
+ attrs.device, device_ordinal
+ )));
+ }
+
+ Ok(())
+}
+
/// Main entry point for Mahout QDP
///
/// Manages GPU context and dispatches encoding tasks.
@@ -345,6 +396,8 @@ impl QdpEngine {
));
}
+ validate_cuda_input_ptr(&self.device, input_d)?;
+
let state_len = 1usize << num_qubits;
let method = encoding_method.to_ascii_lowercase();
@@ -513,6 +566,8 @@ impl QdpEngine {
));
}
+ validate_cuda_input_ptr(&self.device, input_batch_d)?;
+
let state_len = 1usize << num_qubits;
let method = encoding_method.to_ascii_lowercase();
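For context, a crate-internal sketch (not part of this patch) of the failure mode the new check catches: a plain host allocation passed where device memory is expected now fails fast with MahoutError::InvalidInput instead of surfacing later as a kernel fault. The helper name is hypothetical; validate_cuda_input_ptr and CudaDevice are the items shown above.

    #[cfg(target_os = "linux")]
    fn _example_reject_host_pointer(device: &CudaDevice) {
        // Ordinary heap memory has never been registered with the CUDA
        // runtime, so the validation must refuse it.
        let host_buf = vec![0.0f64; 4];
        let err = validate_cuda_input_ptr(device, host_buf.as_ptr())
            .expect_err("host memory must be rejected");
        // Depending on the CUDA version this reports either a memory_type
        // mismatch or a cudaPointerGetAttributes failure.
        eprintln!("{}", err);
    }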
diff --git a/qdp/qdp-core/tests/dlpack.rs b/qdp/qdp-core/tests/dlpack.rs
index f42474a65..6b97283ce 100644
--- a/qdp/qdp-core/tests/dlpack.rs
+++ b/qdp/qdp-core/tests/dlpack.rs
@@ -18,7 +18,10 @@
#[cfg(test)]
mod dlpack_tests {
+ use std::ffi::c_void;
+
use cudarc::driver::CudaDevice;
+ use qdp_core::dlpack::{CUDA_STREAM_LEGACY, synchronize_stream};
use qdp_core::gpu::memory::GpuStateVector;
#[test]
@@ -82,4 +85,29 @@ mod dlpack_tests {
}
}
}
+
+    /// synchronize_stream(null) is a no-op and returns Ok(()) on all platforms.
+ #[test]
+ fn test_synchronize_stream_null() {
+ unsafe {
+ let result = synchronize_stream(std::ptr::null_mut::<c_void>());
+ assert!(
+ result.is_ok(),
+ "synchronize_stream(null) should return Ok(())"
+ );
+ }
+ }
+
+    /// synchronize_stream(CUDA_STREAM_LEGACY) syncs the legacy default stream (Linux + CUDA).
+ #[test]
+ #[cfg(target_os = "linux")]
+ fn test_synchronize_stream_legacy() {
+ unsafe {
+ let result = synchronize_stream(CUDA_STREAM_LEGACY);
+ assert!(
+ result.is_ok(),
+ "synchronize_stream(CUDA_STREAM_LEGACY) should succeed on
Linux with CUDA"
+ );
+ }
+ }
}
diff --git a/qdp/qdp-python/src/lib.rs b/qdp/qdp-python/src/lib.rs
index 8647d98b7..1af58a617 100644
--- a/qdp/qdp-python/src/lib.rs
+++ b/qdp/qdp-python/src/lib.rs
@@ -45,7 +45,7 @@ impl QuantumTensor {
/// The capsule can only be consumed once to prevent double-free errors.
///
/// Args:
- /// stream: Optional CUDA stream pointer (for DLPack 0.8+)
+    ///     stream: Optional CUDA stream (DLPack 0.8+; 1=legacy default, 2=per-thread default)
///
/// Returns:
/// PyCapsule containing DLManagedTensor pointer
@@ -54,7 +54,6 @@ impl QuantumTensor {
/// RuntimeError: If the tensor has already been consumed
#[pyo3(signature = (stream=None))]
    fn __dlpack__<'py>(&mut self, py: Python<'py>, stream: Option<i64>) -> PyResult<Py<PyAny>> {
- let _ = stream; // Suppress unused variable warning
if self.consumed {
return Err(PyRuntimeError::new_err(
"DLPack tensor already consumed (can only be used once)",
@@ -65,6 +64,17 @@ impl QuantumTensor {
return Err(PyRuntimeError::new_err("Invalid DLPack tensor
pointer"));
}
+ if let Some(stream) = stream
+ && stream > 0
+ {
+ let stream_ptr = qdp_core::dlpack::dlpack_stream_to_cuda(stream);
+ unsafe {
+ qdp_core::dlpack::synchronize_stream(stream_ptr).map_err(|e| {
+ PyRuntimeError::new_err(format!("CUDA stream sync failed:
{}", e))
+ })?;
+ }
+ }
+
// Mark as consumed to prevent double-free
self.consumed = true;
diff --git a/testing/qdp/test_bindings.py b/testing/qdp/test_bindings.py
index 58f5d5af6..d213d55cd 100644
--- a/testing/qdp/test_bindings.py
+++ b/testing/qdp/test_bindings.py
@@ -125,6 +125,25 @@ def test_dlpack_single_use():
qtensor2.__dlpack__()
+@requires_qdp
[email protected]
[email protected]("stream", [1, 2], ids=["stream_legacy",
"stream_per_thread"])
+def test_dlpack_with_stream(stream):
+ """Test __dlpack__(stream=...) syncs CUDA stream before returning capsule
(DLPack 0.8+)."""
+ import torch
+ from _qdp import QdpEngine
+
+ engine = QdpEngine(0)
+ data = [1.0, 2.0, 3.0, 4.0]
+ qtensor = engine.encode(data, 2, "amplitude")
+
+    # stream=1 (legacy default) or 2 (per-thread default) should sync and return capsule
+ capsule = qtensor.__dlpack__(stream=stream)
+ torch_tensor = torch.from_dlpack(capsule)
+ assert torch_tensor.is_cuda
+ assert torch_tensor.shape == (1, 4)
+
+
@requires_qdp
@pytest.mark.gpu
def test_pytorch_integration():