This is an automated email from the ASF dual-hosted git repository.
guanmingchiu pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/mahout.git
The following commit(s) were added to refs/heads/main by this push:
new dfe118e18 [QDP] DLPack Tensor unsafe cleanup refactoring (#1011)
dfe118e18 is described below
commit dfe118e18efb7e9c0f9c90ec22b09dd27394ba52
Author: ChenChen Lai <[email protected]>
AuthorDate: Thu Mar 5 10:30:56 2026 +0800
[QDP] DLPack Tensor unsafe cleanup refactoring (#1011)
* DLPack Tensor Cleanup Refactoring
* add test for free_dlpack_tensor
---------
Co-authored-by: Ryan Huang <[email protected]>
Co-authored-by: Guan-Ming (Wesley) Chiu <[email protected]>
---
qdp/qdp-core/examples/dataloader_throughput.rs | 14 ++-
qdp/qdp-core/examples/nvtx_profile.rs | 12 +--
qdp/qdp-core/examples/observability_test.rs | 11 ++-
qdp/qdp-core/src/dlpack.rs | 43 +++++++-
qdp/qdp-core/tests/dlpack.rs | 132 ++++++++++++++++++++++++-
5 files changed, 192 insertions(+), 20 deletions(-)
diff --git a/qdp/qdp-core/examples/dataloader_throughput.rs b/qdp/qdp-core/examples/dataloader_throughput.rs
index d3cb1ea82..cdbdd21d4 100644
--- a/qdp/qdp-core/examples/dataloader_throughput.rs
+++ b/qdp/qdp-core/examples/dataloader_throughput.rs
@@ -24,6 +24,7 @@ use std::thread;
use std::time::{Duration, Instant};
use qdp_core::QdpEngine;
+use qdp_core::dlpack::free_dlpack_tensor;
const BATCH_SIZE: usize = 64;
const VECTOR_LEN: usize = 1024; // 2^10
@@ -99,12 +100,15 @@ fn main() {
debug_assert_eq!(batch.len() % VECTOR_LEN, 0);
let num_samples = batch.len() / VECTOR_LEN;
match engine.encode_batch(&batch, num_samples, VECTOR_LEN, NUM_QUBITS,
"amplitude") {
- Ok(ptr) => unsafe {
- let managed = &mut *ptr;
- if let Some(deleter) = managed.deleter.take() {
- deleter(ptr);
+ Ok(ptr) => {
+ if let Err(e) = unsafe { free_dlpack_tensor(ptr) } {
+ eprintln!(
+ "Failed to free DLPack tensor for batch {} (processed
{} vectors): {:?}",
+ batch_idx, total_vectors, e
+ );
+ return;
}
- },
+ }
Err(e) => {
eprintln!(
"Encode batch failed on batch {} (processed {} vectors):
{:?}",
diff --git a/qdp/qdp-core/examples/nvtx_profile.rs b/qdp/qdp-core/examples/nvtx_profile.rs
index 87ceeff2b..d869a6502 100644
--- a/qdp/qdp-core/examples/nvtx_profile.rs
+++ b/qdp/qdp-core/examples/nvtx_profile.rs
@@ -18,6 +18,7 @@
// Run: cargo run -p qdp-core --example nvtx_profile --features observability --release
use qdp_core::QdpEngine;
+use qdp_core::dlpack::free_dlpack_tensor;
fn main() {
println!("=== NVTX Profiling Example ===");
@@ -61,13 +62,10 @@ fn main() {
println!("✓ Encoding succeeded");
println!("✓ DLPack pointer: {:p}", ptr);
- // Clean up
- unsafe {
- let managed = &mut *ptr;
- if let Some(deleter) = managed.deleter.take() {
- deleter(ptr);
- println!("✓ Memory freed");
- }
+ // Clean up using shared helper with safety checks
+ match unsafe { free_dlpack_tensor(ptr) } {
+ Ok(()) => println!("✓ Memory freed"),
+ Err(e) => eprintln!("✗ Failed to free DLPack tensor: {:?}", e),
}
}
Err(e) => {
diff --git a/qdp/qdp-core/examples/observability_test.rs b/qdp/qdp-core/examples/observability_test.rs
index 462e8aef7..5af7d7c1c 100644
--- a/qdp/qdp-core/examples/observability_test.rs
+++ b/qdp/qdp-core/examples/observability_test.rs
@@ -19,6 +19,7 @@
// Run: cargo run -p qdp-core --example observability_test --release
use qdp_core::QdpEngine;
+use qdp_core::dlpack::free_dlpack_tensor;
use std::env;
fn main() {
@@ -92,12 +93,12 @@ fn main() {
for i in 0..NUM_SAMPLES {
let sample = &test_data[i * VECTOR_LEN..(i + 1) * VECTOR_LEN];
match engine.encode(sample, NUM_QUBITS, "amplitude") {
- Ok(ptr) => unsafe {
- let managed = &mut *ptr;
- if let Some(deleter) = managed.deleter.take() {
- deleter(ptr);
+ Ok(ptr) => {
+ if let Err(e) = unsafe { free_dlpack_tensor(ptr) } {
+ eprintln!("✗ Failed to free DLPack tensor for sample {}:
{:?}", i, e);
+ return;
}
- },
+ }
Err(e) => {
eprintln!("✗ Encoding failed for sample {}: {:?}", i, e);
return;
diff --git a/qdp/qdp-core/src/dlpack.rs b/qdp/qdp-core/src/dlpack.rs
index 1780fa404..1181dbd8e 100644
--- a/qdp/qdp-core/src/dlpack.rs
+++ b/qdp/qdp-core/src/dlpack.rs
@@ -16,9 +16,9 @@
// DLPack protocol for zero-copy GPU memory sharing with PyTorch
-use crate::error::Result;
#[cfg(target_os = "linux")]
-use crate::error::{MahoutError, cuda_error_to_string};
+use crate::error::cuda_error_to_string;
+use crate::error::{MahoutError, Result};
use crate::gpu::memory::{BufferStorage, GpuStateVector, Precision};
use std::os::raw::{c_int, c_void};
use std::sync::Arc;
@@ -205,6 +205,45 @@ pub unsafe extern "C" fn dlpack_deleter(managed: *mut DLManagedTensor) {
let _ = Box::from_raw(managed);
}
+/// Safely free a `DLManagedTensor` pointer returned from encoding APIs.
+///
+/// This helper function centralizes the unsafe pointer dereference and deleter
+/// invocation logic, adding safety checks to prevent common errors like null
+/// pointer dereference and double-free.
+///
+/// # Safety
+/// The caller must ensure:
+/// - `ptr` is a valid `DLManagedTensor` pointer returned from `QdpEngine::encode()`
+/// or similar methods, or is null
+/// - The pointer has not been freed before (either by calling this function
+/// or by PyTorch's DLPack deleter)
+/// - The pointer is not used after this call
+///
+/// # Errors
+/// Returns `Err` if:
+/// - The pointer is null
+/// - The deleter is missing or has already been called
+#[allow(unsafe_op_in_unsafe_fn)]
+pub unsafe fn free_dlpack_tensor(ptr: *mut DLManagedTensor) -> Result<()> {
+ if ptr.is_null() {
+ return Err(MahoutError::InvalidInput(
+ "DLPack pointer is null (nothing to free)".into(),
+ ));
+ }
+
+ // SAFETY: Caller guarantees ptr is valid and not yet freed.
+ // We've checked it's not null above.
+ let managed = &mut *ptr;
+
+ let deleter = managed.deleter.take().ok_or_else(|| {
+ MahoutError::InvalidInput("DLPack deleter missing or already
called".into())
+ })?;
+
+ // Call the DLPack deleter to free memory
+ deleter(ptr);
+ Ok(())
+}
+
impl GpuStateVector {
/// Convert to DLPack format for PyTorch
///
diff --git a/qdp/qdp-core/tests/dlpack.rs b/qdp/qdp-core/tests/dlpack.rs
index c22dda384..29cfc74ce 100644
--- a/qdp/qdp-core/tests/dlpack.rs
+++ b/qdp/qdp-core/tests/dlpack.rs
@@ -21,8 +21,12 @@ mod dlpack_tests {
use std::ffi::c_void;
use cudarc::driver::CudaDevice;
+ use qdp_core::MahoutError;
use qdp_core::Precision;
- use qdp_core::dlpack::{CUDA_STREAM_LEGACY, synchronize_stream};
+ use qdp_core::dlpack::{
+ CUDA_STREAM_LEGACY, DL_FLOAT, DLDataType, DLDevice, DLDeviceType,
DLManagedTensor,
+ DLTensor, free_dlpack_tensor, synchronize_stream,
+ };
use qdp_core::gpu::memory::GpuStateVector;
#[test]
@@ -180,4 +184,130 @@ mod dlpack_tests {
);
}
}
+
+ /// free_dlpack_tensor(null) should return an InvalidInput error instead of panicking.
+ #[test]
+ fn test_free_dlpack_tensor_null_ptr() {
+ unsafe {
+ let result = free_dlpack_tensor(std::ptr::null_mut());
+ match result {
+ Err(MahoutError::InvalidInput(msg)) => {
+ assert!(
+ msg.to_lowercase().contains("null"),
+ "Expected null-pointer error message, got: {}",
+ msg
+ );
+ }
+ other => panic!(
+ "Expected InvalidInput error for null pointer, got: {:?}",
+ other
+ ),
+ }
+ }
+ }
+
+ /// free_dlpack_tensor should detect missing deleter and return an InvalidInput error.
+ #[test]
+ fn test_free_dlpack_tensor_missing_deleter() {
+ // Minimal, but structurally valid, DLTensor for constructing DLManagedTensor.
+ let dummy_tensor = DLTensor {
+ data: std::ptr::null_mut(),
+ device: DLDevice {
+ device_type: DLDeviceType::kDLCPU,
+ device_id: 0,
+ },
+ ndim: 0,
+ dtype: DLDataType {
+ code: DL_FLOAT,
+ bits: 64,
+ lanes: 1,
+ },
+ shape: std::ptr::null_mut(),
+ strides: std::ptr::null_mut(),
+ byte_offset: 0,
+ };
+
+ let managed = DLManagedTensor {
+ dl_tensor: dummy_tensor,
+ manager_ctx: std::ptr::null_mut(),
+ deleter: None,
+ };
+
+ let ptr = Box::into_raw(Box::new(managed));
+
+ unsafe {
+ let result = free_dlpack_tensor(ptr);
+ match result {
+ Err(MahoutError::InvalidInput(msg)) => {
+ assert!(
+ msg.to_lowercase().contains("deleter"),
+ "Expected missing-deleter error message, got: {}",
+ msg
+ );
+ }
+ other => panic!(
+ "Expected InvalidInput error for missing deleter, got:
{:?}",
+ other
+ ),
+ }
+
+ // free_dlpack_tensor must not free the tensor when deleter is missing;
+ // reclaim it here to avoid a leak in tests.
+ let _ = Box::from_raw(ptr);
+ }
+ }
+
+ /// free_dlpack_tensor should call the deleter exactly once and return Ok(()).
+ #[test]
+ fn test_free_dlpack_tensor_calls_deleter() {
+ static mut DELETER_CALLED: bool = false;
+
+ unsafe extern "C" fn test_deleter(_ptr: *mut DLManagedTensor) {
+ // SAFETY: This test is single-threaded; it's safe to mutate the static flag.
+ unsafe {
+ DELETER_CALLED = true;
+ }
+ }
+
+ let dummy_tensor = DLTensor {
+ data: std::ptr::null_mut(),
+ device: DLDevice {
+ device_type: DLDeviceType::kDLCPU,
+ device_id: 0,
+ },
+ ndim: 0,
+ dtype: DLDataType {
+ code: DL_FLOAT,
+ bits: 64,
+ lanes: 1,
+ },
+ shape: std::ptr::null_mut(),
+ strides: std::ptr::null_mut(),
+ byte_offset: 0,
+ };
+
+ let managed = DLManagedTensor {
+ dl_tensor: dummy_tensor,
+ manager_ctx: std::ptr::null_mut(),
+ deleter: Some(test_deleter),
+ };
+
+ let ptr = Box::into_raw(Box::new(managed));
+
+ unsafe {
+ let result = free_dlpack_tensor(ptr);
+ assert!(
+ result.is_ok(),
+ "free_dlpack_tensor should succeed for valid pointer: {:?}",
+ result
+ );
+ assert!(
+ DELETER_CALLED,
+ "free_dlpack_tensor should invoke the DLPack deleter"
+ );
+
+ // Our custom deleter doesn't free the allocation; reclaim it here.
+ let _ = Box::from_raw(ptr);
+ }
+ }
}