CheyuWu commented on code in PR #934:
URL: https://github.com/apache/mahout/pull/934#discussion_r2728286216


##########
qdp/qdp-core/src/lib.rs:
##########
@@ -313,108 +313,168 @@ impl QdpEngine {
     ///
     /// TODO: Refactor to use QuantumEncoder trait (add `encode_from_gpu_ptr` 
to trait)
     /// to reduce duplication with AmplitudeEncoder::encode(). This would also 
make it
-    /// easier to add GPU pointer support for other encoders (angle, basis) in 
the future.
+    /// easier to add GPU pointer support for other encoders (angle) in the 
future.
     ///
     /// # Arguments
-    /// * `input_d` - Device pointer to input data (f64 array on GPU)
-    /// * `input_len` - Number of f64 elements in the input
+    /// * `input_d` - Device pointer to input data (f64 for amplitude, 
usize/int64 for basis)
+    /// * `input_len` - Number of elements in the input
     /// * `num_qubits` - Number of qubits for encoding
-    /// * `encoding_method` - Strategy (currently only "amplitude" supported)
+    /// * `encoding_method` - Strategy ("amplitude" or "basis")
     ///
     /// # Returns
     /// DLPack pointer for zero-copy PyTorch integration
     ///
     /// # Safety
     /// The input pointer must:
     /// - Point to valid GPU memory on the same device as the engine
-    /// - Contain at least `input_len` f64 elements
+    /// - Contain at least `input_len` elements of the expected dtype
     /// - Remain valid for the duration of this call
     #[cfg(target_os = "linux")]
     pub unsafe fn encode_from_gpu_ptr(
         &self,
-        input_d: *const f64,
+        input_d: *const std::ffi::c_void,
         input_len: usize,
         num_qubits: usize,
         encoding_method: &str,
     ) -> Result<*mut DLManagedTensor> {
         crate::profile_scope!("Mahout::EncodeFromGpuPtr");
 
-        if encoding_method != "amplitude" {
-            return Err(MahoutError::NotImplemented(format!(
-                "GPU pointer encoding currently only supports 'amplitude' 
method, got '{}'",
-                encoding_method
-            )));
-        }
-
-        if input_len == 0 {
-            return Err(MahoutError::InvalidInput(
-                "Input data cannot be empty".into(),
-            ));
-        }
-
         let state_len = 1usize << num_qubits;
-        if input_len > state_len {
-            return Err(MahoutError::InvalidInput(format!(
-                "Input size {} exceeds state vector size {} (2^{} qubits)",
-                input_len, state_len, num_qubits
-            )));
-        }
-
-        // Allocate output state vector
-        let state_vector = {
-            crate::profile_scope!("GPU::Alloc");
-            gpu::GpuStateVector::new(&self.device, num_qubits)?
-        };
-
-        // Compute inverse L2 norm on GPU
-        let inv_norm = {
-            crate::profile_scope!("GPU::NormFromPtr");
-            // SAFETY: input_d validity is guaranteed by the caller's safety 
contract
-            unsafe {
-                gpu::AmplitudeEncoder::calculate_inv_norm_gpu(&self.device, 
input_d, input_len)?
+        match encoding_method {
+            "amplitude" => {
+                if input_len == 0 {
+                    return Err(MahoutError::InvalidInput(
+                        "Input data cannot be empty".into(),
+                    ));
+                }
+
+                if input_len > state_len {
+                    return Err(MahoutError::InvalidInput(format!(
+                        "Input size {} exceeds state vector size {} (2^{} 
qubits)",
+                        input_len, state_len, num_qubits
+                    )));
+                }
+
+                let input_d = input_d as *const f64;
+
+                // Allocate output state vector
+                let state_vector = {
+                    crate::profile_scope!("GPU::Alloc");
+                    gpu::GpuStateVector::new(&self.device, num_qubits)?
+                };
+
+                // Compute inverse L2 norm on GPU
+                let inv_norm = {
+                    crate::profile_scope!("GPU::NormFromPtr");
+                    // SAFETY: input_d validity is guaranteed by the caller's 
safety contract
+                    unsafe {
+                        gpu::AmplitudeEncoder::calculate_inv_norm_gpu(
+                            &self.device,
+                            input_d,
+                            input_len,
+                        )?
+                    }
+                };
+
+                // Get output pointer
+                let state_ptr = state_vector.ptr_f64().ok_or_else(|| {
+                    MahoutError::InvalidInput(
+                        "State vector precision mismatch (expected float64 
buffer)".to_string(),
+                    )
+                })?;
+
+                // Launch encoding kernel
+                {
+                    crate::profile_scope!("GPU::KernelLaunch");
+                    let ret = unsafe {
+                        qdp_kernels::launch_amplitude_encode(
+                            input_d,
+                            state_ptr as *mut std::ffi::c_void,
+                            input_len,
+                            state_len,
+                            inv_norm,
+                            std::ptr::null_mut(), // default stream
+                        )
+                    };
+
+                    if ret != 0 {
+                        return Err(MahoutError::KernelLaunch(format!(
+                            "Amplitude encode kernel failed with CUDA error 
code: {} ({})",
+                            ret,
+                            cuda_error_to_string(ret)
+                        )));
+                    }
+                }
+
+                // Synchronize
+                {
+                    crate::profile_scope!("GPU::Synchronize");
+                    self.device.synchronize().map_err(|e| {
+                        MahoutError::Cuda(format!("CUDA device synchronize 
failed: {:?}", e))
+                    })?;
+                }
+
+                let state_vector = state_vector.to_precision(&self.device, 
self.precision)?;
+                Ok(state_vector.to_dlpack())
             }
-        };
-
-        // Get output pointer
-        let state_ptr = state_vector.ptr_f64().ok_or_else(|| {
-            MahoutError::InvalidInput(
-                "State vector precision mismatch (expected float64 
buffer)".to_string(),
-            )
-        })?;
-
-        // Launch encoding kernel
-        {
-            crate::profile_scope!("GPU::KernelLaunch");
-            let ret = unsafe {
-                qdp_kernels::launch_amplitude_encode(
-                    input_d,
-                    state_ptr as *mut std::ffi::c_void,
-                    input_len,
-                    state_len,
-                    inv_norm,
-                    std::ptr::null_mut(), // default stream
-                )
-            };
-
-            if ret != 0 {
-                return Err(MahoutError::KernelLaunch(format!(
-                    "Amplitude encode kernel failed with CUDA error code: {} 
({})",
-                    ret,
-                    cuda_error_to_string(ret)
-                )));
+            "basis" => {
+                if input_len != 1 {
+                    return Err(MahoutError::InvalidInput(format!(
+                        "Basis encoding expects exactly 1 value (the basis 
index), got {}",
+                        input_len
+                    )));
+                }
+
+                let basis_indices_d = input_d as *const usize;

Review Comment:
   I encountered an issue where directly casting the input from `f64` to 
`usize` results in a loss of precision due to truncation.
   
   I’m not sure whether using `std::ffi::c_void` in this case could introduce 
additional issues, especially around type safety or incorrect casting.
   
   My solution is to write a new CUDA function that addresses this: 
https://github.com/apache/mahout/pull/937/files#diff-22afb0c3753a9d7765229f78aa1e96463b958cd6bab724527dc540a17e28592aR47
   



-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: [email protected]

For queries about this service, please contact Infrastructure at:
[email protected]

Reply via email to