This is an automated email from the ASF dual-hosted git repository.

guanmingchiu pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/mahout.git


The following commit(s) were added to refs/heads/main by this push:
     new dfe118e18 [QDP] DLPack Tensor unsafe cleanup refactoring (#1011)
dfe118e18 is described below

commit dfe118e18efb7e9c0f9c90ec22b09dd27394ba52
Author: ChenChen Lai <[email protected]>
AuthorDate: Thu Mar 5 10:30:56 2026 +0800

    [QDP] DLPack Tensor unsafe cleanup refactoring (#1011)
    
    * DLPack Tensor Cleanup Refactoring
    
    * add test for free_dlpack_tensor
    
    ---------
    
    Co-authored-by: Ryan Huang <[email protected]>
    Co-authored-by: Guan-Ming (Wesley) Chiu <[email protected]>
---
 qdp/qdp-core/examples/dataloader_throughput.rs |  14 ++-
 qdp/qdp-core/examples/nvtx_profile.rs          |  12 +--
 qdp/qdp-core/examples/observability_test.rs    |  11 ++-
 qdp/qdp-core/src/dlpack.rs                     |  43 +++++++-
 qdp/qdp-core/tests/dlpack.rs                   | 132 ++++++++++++++++++++++++-
 5 files changed, 192 insertions(+), 20 deletions(-)

diff --git a/qdp/qdp-core/examples/dataloader_throughput.rs b/qdp/qdp-core/examples/dataloader_throughput.rs
index d3cb1ea82..cdbdd21d4 100644
--- a/qdp/qdp-core/examples/dataloader_throughput.rs
+++ b/qdp/qdp-core/examples/dataloader_throughput.rs
@@ -24,6 +24,7 @@ use std::thread;
 use std::time::{Duration, Instant};
 
 use qdp_core::QdpEngine;
+use qdp_core::dlpack::free_dlpack_tensor;
 
 const BATCH_SIZE: usize = 64;
 const VECTOR_LEN: usize = 1024; // 2^10
@@ -99,12 +100,15 @@ fn main() {
         debug_assert_eq!(batch.len() % VECTOR_LEN, 0);
         let num_samples = batch.len() / VECTOR_LEN;
         match engine.encode_batch(&batch, num_samples, VECTOR_LEN, NUM_QUBITS, "amplitude") {
-            Ok(ptr) => unsafe {
-                let managed = &mut *ptr;
-                if let Some(deleter) = managed.deleter.take() {
-                    deleter(ptr);
+            Ok(ptr) => {
+                if let Err(e) = unsafe { free_dlpack_tensor(ptr) } {
+                    eprintln!(
+                        "Failed to free DLPack tensor for batch {} (processed {} vectors): {:?}",
+                        batch_idx, total_vectors, e
+                    );
+                    return;
                 }
-            },
+            }
             Err(e) => {
                 eprintln!(
                     "Encode batch failed on batch {} (processed {} vectors): {:?}",
diff --git a/qdp/qdp-core/examples/nvtx_profile.rs b/qdp/qdp-core/examples/nvtx_profile.rs
index 87ceeff2b..d869a6502 100644
--- a/qdp/qdp-core/examples/nvtx_profile.rs
+++ b/qdp/qdp-core/examples/nvtx_profile.rs
@@ -18,6 +18,7 @@
 // Run: cargo run -p qdp-core --example nvtx_profile --features observability --release
 
 use qdp_core::QdpEngine;
+use qdp_core::dlpack::free_dlpack_tensor;
 
 fn main() {
     println!("=== NVTX Profiling Example ===");
@@ -61,13 +62,10 @@ fn main() {
             println!("✓ Encoding succeeded");
             println!("✓ DLPack pointer: {:p}", ptr);
 
-            // Clean up
-            unsafe {
-                let managed = &mut *ptr;
-                if let Some(deleter) = managed.deleter.take() {
-                    deleter(ptr);
-                    println!("✓ Memory freed");
-                }
+            // Clean up using shared helper with safety checks
+            match unsafe { free_dlpack_tensor(ptr) } {
+                Ok(()) => println!("✓ Memory freed"),
+                Err(e) => eprintln!("✗ Failed to free DLPack tensor: {:?}", e),
             }
         }
         Err(e) => {
diff --git a/qdp/qdp-core/examples/observability_test.rs b/qdp/qdp-core/examples/observability_test.rs
index 462e8aef7..5af7d7c1c 100644
--- a/qdp/qdp-core/examples/observability_test.rs
+++ b/qdp/qdp-core/examples/observability_test.rs
@@ -19,6 +19,7 @@
 // Run: cargo run -p qdp-core --example observability_test --release
 
 use qdp_core::QdpEngine;
+use qdp_core::dlpack::free_dlpack_tensor;
 use std::env;
 
 fn main() {
@@ -92,12 +93,12 @@ fn main() {
     for i in 0..NUM_SAMPLES {
         let sample = &test_data[i * VECTOR_LEN..(i + 1) * VECTOR_LEN];
         match engine.encode(sample, NUM_QUBITS, "amplitude") {
-            Ok(ptr) => unsafe {
-                let managed = &mut *ptr;
-                if let Some(deleter) = managed.deleter.take() {
-                    deleter(ptr);
+            Ok(ptr) => {
+                if let Err(e) = unsafe { free_dlpack_tensor(ptr) } {
+                    eprintln!("✗ Failed to free DLPack tensor for sample {}: {:?}", i, e);
+                    return;
                 }
-            },
+            }
             Err(e) => {
                 eprintln!("✗ Encoding failed for sample {}: {:?}", i, e);
                 return;
diff --git a/qdp/qdp-core/src/dlpack.rs b/qdp/qdp-core/src/dlpack.rs
index 1780fa404..1181dbd8e 100644
--- a/qdp/qdp-core/src/dlpack.rs
+++ b/qdp/qdp-core/src/dlpack.rs
@@ -16,9 +16,9 @@
 
 // DLPack protocol for zero-copy GPU memory sharing with PyTorch
 
-use crate::error::Result;
 #[cfg(target_os = "linux")]
-use crate::error::{MahoutError, cuda_error_to_string};
+use crate::error::cuda_error_to_string;
+use crate::error::{MahoutError, Result};
 use crate::gpu::memory::{BufferStorage, GpuStateVector, Precision};
 use std::os::raw::{c_int, c_void};
 use std::sync::Arc;
@@ -205,6 +205,51 @@ pub unsafe extern "C" fn dlpack_deleter(managed: *mut DLManagedTensor) {
     let _ = Box::from_raw(managed);
 }
 
+/// Free a `DLManagedTensor` pointer returned from encoding APIs.
+///
+/// Centralizes the raw-pointer dereference and the DLPack deleter invocation
+/// so callers do not repeat that unsafe sequence. The deleter is moved out of
+/// the struct (`Option::take`) before it runs, so this helper never invokes
+/// the same deleter twice on a still-live tensor.
+///
+/// It cannot detect a tensor that has already been released: the crate's
+/// `dlpack_deleter` drops the `DLManagedTensor` allocation itself, so any later
+/// access through `ptr` (including a second call to this function) is
+/// undefined behavior, not an `Err`.
+///
+/// # Safety
+/// The caller must ensure:
+/// - `ptr` is null, or a valid `DLManagedTensor` pointer returned from
+///   `QdpEngine::encode()` or a similar method
+/// - the tensor has not been released yet (neither by this function nor by a
+///   DLPack consumer such as PyTorch calling its deleter)
+/// - `ptr` is not used after this call returns `Ok(())`
+///
+/// # Errors
+/// Returns `MahoutError::InvalidInput` if:
+/// - `ptr` is null
+/// - the tensor's `deleter` is `None`; nothing is freed in that case and the
+///   caller still owns the allocation
+pub unsafe fn free_dlpack_tensor(ptr: *mut DLManagedTensor) -> Result<()> {
+    if ptr.is_null() {
+        return Err(MahoutError::InvalidInput(
+            "DLPack pointer is null (nothing to free)".into(),
+        ));
+    }
+
+    // SAFETY: `ptr` is non-null (checked above) and the caller guarantees it
+    // points to a live, not-yet-released `DLManagedTensor`.
+    let deleter = unsafe { (*ptr).deleter.take() }.ok_or_else(|| {
+        MahoutError::InvalidInput("DLPack deleter missing or already called".into())
+    })?;
+
+    // SAFETY: the deleter was installed by the tensor's producer for this
+    // pointer and has just been removed from the struct, so it runs once here.
+    // `ptr` may be dangling after this call and must not be used again.
+    unsafe { deleter(ptr) };
+    Ok(())
+}
+
 impl GpuStateVector {
     /// Convert to DLPack format for PyTorch
     ///
diff --git a/qdp/qdp-core/tests/dlpack.rs b/qdp/qdp-core/tests/dlpack.rs
index c22dda384..29cfc74ce 100644
--- a/qdp/qdp-core/tests/dlpack.rs
+++ b/qdp/qdp-core/tests/dlpack.rs
@@ -21,8 +21,12 @@ mod dlpack_tests {
     use std::ffi::c_void;
 
     use cudarc::driver::CudaDevice;
+    use qdp_core::MahoutError;
     use qdp_core::Precision;
-    use qdp_core::dlpack::{CUDA_STREAM_LEGACY, synchronize_stream};
+    use qdp_core::dlpack::{
+        CUDA_STREAM_LEGACY, DL_FLOAT, DLDataType, DLDevice, DLDeviceType, DLManagedTensor,
+        DLTensor, free_dlpack_tensor, synchronize_stream,
+    };
     use qdp_core::gpu::memory::GpuStateVector;
 
     #[test]
@@ -180,4 +184,130 @@ mod dlpack_tests {
             );
         }
     }
+
+    /// Build a minimal, structurally valid CPU `DLManagedTensor` with no data,
+    /// no manager context, and no deleter, for exercising `free_dlpack_tensor`.
+    fn dummy_managed_tensor() -> DLManagedTensor {
+        DLManagedTensor {
+            dl_tensor: DLTensor {
+                data: std::ptr::null_mut(),
+                device: DLDevice {
+                    device_type: DLDeviceType::kDLCPU,
+                    device_id: 0,
+                },
+                ndim: 0,
+                dtype: DLDataType {
+                    code: DL_FLOAT,
+                    bits: 64,
+                    lanes: 1,
+                },
+                shape: std::ptr::null_mut(),
+                strides: std::ptr::null_mut(),
+                byte_offset: 0,
+            },
+            manager_ctx: std::ptr::null_mut(),
+            deleter: None,
+        }
+    }
+
+    /// free_dlpack_tensor(null) should return an InvalidInput error instead of panicking.
+    #[test]
+    fn test_free_dlpack_tensor_null_ptr() {
+        // SAFETY: null is explicitly permitted by the function's contract.
+        let result = unsafe { free_dlpack_tensor(std::ptr::null_mut()) };
+        match result {
+            Err(MahoutError::InvalidInput(msg)) => {
+                assert!(
+                    msg.to_lowercase().contains("null"),
+                    "Expected null-pointer error message, got: {}",
+                    msg
+                );
+            }
+            other => panic!(
+                "Expected InvalidInput error for null pointer, got: {:?}",
+                other
+            ),
+        }
+    }
+
+    /// free_dlpack_tensor should detect a missing deleter and return an InvalidInput error.
+    #[test]
+    fn test_free_dlpack_tensor_missing_deleter() {
+        let ptr = Box::into_raw(Box::new(dummy_managed_tensor()));
+
+        // SAFETY: `ptr` comes from `Box::into_raw` and is live.
+        let result = unsafe { free_dlpack_tensor(ptr) };
+        match result {
+            Err(MahoutError::InvalidInput(msg)) => {
+                assert!(
+                    msg.to_lowercase().contains("deleter"),
+                    "Expected missing-deleter error message, got: {}",
+                    msg
+                );
+            }
+            other => panic!(
+                "Expected InvalidInput error for missing deleter, got: {:?}",
+                other
+            ),
+        }
+
+        // free_dlpack_tensor must not free the tensor when the deleter is missing;
+        // reclaim it here to avoid a leak in tests.
+        // SAFETY: `ptr` came from `Box::into_raw` above and has not been freed.
+        let _ = unsafe { Box::from_raw(ptr) };
+    }
+
+    /// free_dlpack_tensor should call the deleter exactly once with the original
+    /// pointer, clear the deleter slot, and return Ok(()).
+    #[test]
+    fn test_free_dlpack_tensor_calls_deleter() {
+        use std::sync::atomic::{AtomicPtr, AtomicUsize, Ordering};
+
+        static DELETER_CALLS: AtomicUsize = AtomicUsize::new(0);
+        static DELETER_ARG: AtomicPtr<DLManagedTensor> = AtomicPtr::new(std::ptr::null_mut());
+
+        // Records the call instead of freeing, so the test can inspect `ptr` afterwards.
+        unsafe extern "C" fn test_deleter(ptr: *mut DLManagedTensor) {
+            DELETER_CALLS.fetch_add(1, Ordering::SeqCst);
+            DELETER_ARG.store(ptr, Ordering::SeqCst);
+        }
+
+        let mut managed = dummy_managed_tensor();
+        managed.deleter = Some(test_deleter);
+        let ptr = Box::into_raw(Box::new(managed));
+
+        // SAFETY: `ptr` comes from `Box::into_raw` and is live; `test_deleter`
+        // does not free it, so it stays valid after the call.
+        let result = unsafe { free_dlpack_tensor(ptr) };
+        assert!(
+            result.is_ok(),
+            "free_dlpack_tensor should succeed for valid pointer: {:?}",
+            result
+        );
+        assert_eq!(
+            DELETER_CALLS.load(Ordering::SeqCst),
+            1,
+            "free_dlpack_tensor should invoke the DLPack deleter exactly once"
+        );
+        assert_eq!(
+            DELETER_ARG.load(Ordering::SeqCst),
+            ptr,
+            "the deleter should receive the pointer passed to free_dlpack_tensor"
+        );
+
+        // The deleter was taken out of the struct, so a second call on this
+        // (still allocated) tensor is rejected instead of running it again.
+        // SAFETY: same as above; the allocation is still live.
+        let second = unsafe { free_dlpack_tensor(ptr) };
+        assert!(
+            matches!(second, Err(MahoutError::InvalidInput(_))),
+            "second free should report the missing deleter, got: {:?}",
+            second
+        );
+        assert_eq!(DELETER_CALLS.load(Ordering::SeqCst), 1);
+
+        // Our custom deleter doesn't free the allocation; reclaim it here.
+        // SAFETY: `ptr` came from `Box::into_raw` above and was never freed.
+        let _ = unsafe { Box::from_raw(ptr) };
+    }
 }

Reply via email to