This is an automated email from the ASF dual-hosted git repository.

guanmingchiu pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/mahout.git

commit cf53ba17537039bc714abcb92af436efa1c45a54
Author: Ping <[email protected]>
AuthorDate: Tue Dec 16 11:31:01 2025 +0800

    [QDP] Set float32 to Default and leave float64 as Optional (#712)
    
    * Set float32 to default
    
    Signed-off-by: 400Ping <[email protected]>
    
    * fix errors
    
    Signed-off-by: 400Ping <[email protected]>
    
    * fix python bindings
    
    Signed-off-by: 400Ping <[email protected]>
    
    * [chore] update log
    
    Signed-off-by: 400Ping <[email protected]>
    
    ---------
    
    Signed-off-by: 400Ping <[email protected]>
---
 qdp/qdp-core/src/dlpack.rs                  |  15 ++-
 qdp/qdp-core/src/gpu/encodings/amplitude.rs |  23 ++++-
 qdp/qdp-core/src/gpu/memory.rs              | 142 +++++++++++++++++++++++++---
 qdp/qdp-core/src/lib.rs                     |  13 ++-
 qdp/qdp-kernels/src/amplitude.cu            |  36 +++++++
 qdp/qdp-kernels/src/lib.rs                  |  39 ++++++++
 qdp/qdp-python/README.md                    |   5 +-
 qdp/qdp-python/src/lib.rs                   |  20 +++-
 qdp/qdp-python/tests/test_bindings.py       |  17 +++-
 qdp/qdp-python/tests/test_high_fidelity.py  |  20 +++-
 10 files changed, 296 insertions(+), 34 deletions(-)
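
For context, an illustrative sketch (not part of the patch) of what the new default
means at the qdp-core API level, based on the QdpEngine changes in
qdp/qdp-core/src/lib.rs below:

```rust
// Illustrative sketch only; mirrors the public API added in this commit.
use qdp_core::{Precision, QdpEngine, Result};

fn init_engines() -> Result<()> {
    // Default: encoded state vectors are now float32 (exported as complex64).
    let _default_engine = QdpEngine::new(0)?;

    // Opt in to float64 (complex128) when higher precision is required.
    let _f64_engine = QdpEngine::new_with_precision(0, Precision::Float64)?;

    Ok(())
}
```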

diff --git a/qdp/qdp-core/src/dlpack.rs b/qdp/qdp-core/src/dlpack.rs
index 42d7eaf62..883d19b37 100644
--- a/qdp/qdp-core/src/dlpack.rs
+++ b/qdp/qdp-core/src/dlpack.rs
@@ -18,7 +18,7 @@
 
 use std::os::raw::{c_int, c_void};
 use std::sync::Arc;
-use crate::gpu::memory::GpuStateVector;
+use crate::gpu::memory::{BufferStorage, GpuStateVector, Precision};
 
 // DLPack C structures (matching dlpack/dlpack.h)
 
@@ -104,7 +104,7 @@ pub unsafe extern "C" fn dlpack_deleter(managed: *mut DLManagedTensor) {
     // 3. Free GPU buffer (Arc reference count)
     let ctx = (*managed).manager_ctx;
     if !ctx.is_null() {
-        let _ = Arc::from_raw(ctx as *const crate::gpu::memory::GpuBufferRaw);
+        let _ = Arc::from_raw(ctx as *const BufferStorage);
     }
 
     // 4. Free DLManagedTensor
@@ -131,16 +131,21 @@ impl GpuStateVector {
         // Increment Arc ref count (decremented in deleter)
         let ctx = Arc::into_raw(self.buffer.clone()) as *mut c_void;
 
+        let dtype_bits = match self.precision() {
+            Precision::Float32 => 64, // complex64 (2x float32)
+            Precision::Float64 => 128, // complex128 (2x float64)
+        };
+
         let tensor = DLTensor {
-            data: self.ptr() as *mut c_void,
+            data: self.ptr_void(),
             device: DLDevice {
                 device_type: DLDeviceType::kDLCUDA,
                 device_id: 0,
             },
             ndim: 1,
             dtype: DLDataType {
-                code: DL_COMPLEX,  // Complex128
-                bits: 128,         // 2 * 64-bit floats
+                code: DL_COMPLEX,
+                bits: dtype_bits,
                 lanes: 1,
             },
             shape: shape_ptr,
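
Illustrative sketch (not part of the patch): the precision-to-dtype mapping introduced
above, written as a standalone helper. PyTorch interprets the exported tensors as
torch.complex64 and torch.complex128, as the updated Python tests below assert.

```rust
// Illustrative helper only; the real logic lives inline in to_dlpack above.
use qdp_core::Precision;

// DLPack records the bit width of one element, so a complex value counts both
// the real and imaginary components.
fn complex_dtype_bits(precision: Precision) -> u8 {
    match precision {
        Precision::Float32 => 64,  // complex64  = 2 x 32-bit floats
        Precision::Float64 => 128, // complex128 = 2 x 64-bit floats
    }
}
```
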
diff --git a/qdp/qdp-core/src/gpu/encodings/amplitude.rs b/qdp/qdp-core/src/gpu/encodings/amplitude.rs
index 844f6d1e8..715f0984e 100644
--- a/qdp/qdp-core/src/gpu/encodings/amplitude.rs
+++ b/qdp/qdp-core/src/gpu/encodings/amplitude.rs
@@ -99,12 +99,20 @@ impl QuantumEncoder for AmplitudeEncoder {
                     1.0 / norm
                 };
 
+                let state_ptr = state_vector.ptr_f64().ok_or_else(|| {
+                    let actual = state_vector.precision();
+                    MahoutError::InvalidInput(format!(
+                        "State vector precision mismatch (expected float64 buffer, got {:?})",
+                        actual
+                    ))
+                })?;
+
                 let ret = {
                     crate::profile_scope!("GPU::KernelLaunch");
                     unsafe {
                         launch_amplitude_encode(
                             *input_slice.device_ptr() as *const f64,
-                            state_vector.ptr() as *mut c_void,
+                            state_ptr as *mut c_void,
                             host_data.len(),
                             state_len,
                             inv_norm,
@@ -227,10 +235,13 @@ impl QuantumEncoder for AmplitudeEncoder {
         // Launch batch kernel
         {
             crate::profile_scope!("GPU::BatchKernelLaunch");
+            let state_ptr = batch_state_vector.ptr_f64().ok_or_else(|| MahoutError::InvalidInput(
+                "Batch state vector precision mismatch (expected float64 buffer)".to_string()
+            ))?;
             let ret = unsafe {
                 launch_amplitude_encode_batch(
                     *input_batch_gpu.device_ptr() as *const f64,
-                    batch_state_vector.ptr() as *mut c_void,
+                    state_ptr as *mut c_void,
                     *inv_norms_gpu.device_ptr() as *const f64,
                     num_samples,
                     sample_size,
@@ -283,13 +294,17 @@ impl AmplitudeEncoder {
         inv_norm: f64,
         state_vector: &GpuStateVector,
     ) -> Result<()> {
+        let base_state_ptr = state_vector.ptr_f64().ok_or_else(|| MahoutError::InvalidInput(
+            "State vector precision mismatch (expected float64 buffer)".to_string()
+        ))?;
+
         // Use generic pipeline infrastructure
         // The closure handles amplitude-specific kernel launch logic
         run_dual_stream_pipeline(device, host_data, |stream, input_ptr, chunk_offset, chunk_len| {
             // Calculate offset pointer for state vector (type-safe pointer arithmetic)
             // Offset is in complex numbers (CuDoubleComplex), not f64 elements
             let state_ptr_offset = unsafe {
-                state_vector.ptr().cast::<u8>()
+                base_state_ptr.cast::<u8>()
                     .add(chunk_offset * std::mem::size_of::<qdp_kernels::CuDoubleComplex>())
                     .cast::<std::ffi::c_void>()
             };
@@ -338,7 +353,7 @@ impl AmplitudeEncoder {
 
             // Calculate tail pointer (in complex numbers)
             let tail_ptr = unsafe {
-                state_vector.ptr().add(padding_start) as *mut c_void
+                base_state_ptr.add(padding_start) as *mut c_void
             };
 
             // Zero-fill padding region using CUDA Runtime API
diff --git a/qdp/qdp-core/src/gpu/memory.rs b/qdp/qdp-core/src/gpu/memory.rs
index a333e103d..72bf5bc7f 100644
--- a/qdp/qdp-core/src/gpu/memory.rs
+++ b/qdp/qdp-core/src/gpu/memory.rs
@@ -13,12 +13,19 @@
 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 // See the License for the specific language governing permissions and
 // limitations under the License.
-
+use std::ffi::c_void;
 use std::sync::Arc;
 use cudarc::driver::{CudaDevice, CudaSlice, DevicePtr};
-use qdp_kernels::CuDoubleComplex;
+use qdp_kernels::{CuComplex, CuDoubleComplex};
 use crate::error::{MahoutError, Result};
 
+/// Precision of the GPU state vector.
+#[derive(Clone, Copy, Debug, PartialEq, Eq)]
+pub enum Precision {
+    Float32,
+    Float64,
+}
+
 #[cfg(target_os = "linux")]
 fn bytes_to_mib(bytes: usize) -> f64 {
     bytes as f64 / (1024.0 * 1024.0)
@@ -130,28 +137,58 @@ pub(crate) fn map_allocation_error(
 
 /// RAII wrapper for GPU memory buffer
 /// Automatically frees GPU memory when dropped
-pub struct GpuBufferRaw {
-    pub(crate) slice: CudaSlice<CuDoubleComplex>,
+pub struct GpuBufferRaw<T> {
+    pub(crate) slice: CudaSlice<T>,
 }
 
-impl GpuBufferRaw {
+impl<T> GpuBufferRaw<T> {
     /// Get raw pointer to GPU memory
     ///
     /// # Safety
     /// Valid only while GpuBufferRaw is alive
-    pub fn ptr(&self) -> *mut CuDoubleComplex {
-        *self.slice.device_ptr() as *mut CuDoubleComplex
+    pub fn ptr(&self) -> *mut T {
+        *self.slice.device_ptr() as *mut T
+    }
+}
+
+/// Storage wrapper for precision-specific GPU buffers
+pub enum BufferStorage {
+    F32(GpuBufferRaw<CuComplex>),
+    F64(GpuBufferRaw<CuDoubleComplex>),
+}
+
+impl BufferStorage {
+    fn precision(&self) -> Precision {
+        match self {
+            BufferStorage::F32(_) => Precision::Float32,
+            BufferStorage::F64(_) => Precision::Float64,
+        }
+    }
+
+    fn ptr_void(&self) -> *mut c_void {
+        match self {
+            BufferStorage::F32(buf) => buf.ptr() as *mut c_void,
+            BufferStorage::F64(buf) => buf.ptr() as *mut c_void,
+        }
+    }
+
+    fn ptr_f64(&self) -> Option<*mut CuDoubleComplex> {
+        match self {
+            BufferStorage::F64(buf) => Some(buf.ptr()),
+            _ => None,
+        }
     }
 }
 
 /// Quantum state vector on GPU
 ///
-/// Manages complex128 array of size 2^n (n = qubits) in GPU memory.
+/// Manages complex array of size 2^n (n = qubits) in GPU memory.
 /// Uses Arc for shared ownership (needed for DLPack/PyTorch integration).
 /// Thread-safe: Send + Sync
+#[derive(Clone)]
 pub struct GpuStateVector {
     // Use Arc to allow DLPack to share ownership
-    pub(crate) buffer: Arc<GpuBufferRaw>,
+    pub(crate) buffer: Arc<BufferStorage>,
     pub num_qubits: usize,
     pub size_elements: usize,
 }
@@ -190,7 +227,7 @@ impl GpuStateVector {
             ))?;
 
             Ok(Self {
-                buffer: Arc::new(GpuBufferRaw { slice }),
+                buffer: Arc::new(BufferStorage::F64(GpuBufferRaw { slice })),
                 num_qubits: qubits,
                 size_elements: _size_elements,
             })
@@ -203,12 +240,22 @@ impl GpuStateVector {
         }
     }
 
+    /// Get current precision of the underlying buffer.
+    pub fn precision(&self) -> Precision {
+        self.buffer.precision()
+    }
+
     /// Get raw GPU pointer for DLPack/FFI
     ///
     /// # Safety
     /// Valid while GpuStateVector or any Arc clone is alive
-    pub fn ptr(&self) -> *mut CuDoubleComplex {
-        self.buffer.ptr()
+    pub fn ptr_void(&self) -> *mut c_void {
+        self.buffer.ptr_void()
+    }
+
+    /// Returns a double-precision pointer if the buffer stores complex128 data.
+    pub fn ptr_f64(&self) -> Option<*mut CuDoubleComplex> {
+        self.buffer.ptr_f64()
     }
 
     /// Get the number of qubits
@@ -251,7 +298,7 @@ impl GpuStateVector {
             ))?;
 
             Ok(Self {
-                buffer: Arc::new(GpuBufferRaw { slice }),
+                buffer: Arc::new(BufferStorage::F64(GpuBufferRaw { slice })),
                 num_qubits: qubits,
                 size_elements: total_elements,
             })
@@ -262,4 +309,73 @@ impl GpuStateVector {
             Err(MahoutError::Cuda("CUDA is only available on Linux. This build does not support GPU operations.".to_string()))
         }
     }
+
+    /// Convert the state vector to the requested precision (GPU-side).
+    ///
+    /// For now only down-conversion from Float64 -> Float32 is supported.
+    pub fn to_precision(&self, device: &Arc<CudaDevice>, target: Precision) -> Result<Self> {
+        if self.precision() == target {
+            return Ok(self.clone());
+        }
+
+        match (self.precision(), target) {
+            (Precision::Float64, Precision::Float32) => {
+                #[cfg(target_os = "linux")]
+                {
+                    let requested_bytes = self.size_elements
+                        .checked_mul(std::mem::size_of::<CuComplex>())
+                        .ok_or_else(|| MahoutError::MemoryAllocation(
+                            format!("Requested GPU allocation size overflow (elements={})", self.size_elements)
+                        ))?;
+
+                    ensure_device_memory_available(requested_bytes, "state vector precision conversion", Some(self.num_qubits))?;
+
+                    let slice = unsafe {
+                        device.alloc::<CuComplex>(self.size_elements)
+                    }.map_err(|e| map_allocation_error(
+                        requested_bytes,
+                        "state vector precision conversion",
+                        Some(self.num_qubits),
+                        e,
+                    ))?;
+
+                    let src_ptr = self.ptr_f64().ok_or_else(|| MahoutError::InvalidInput(
+                        "Source state vector is not Float64; cannot convert to Float32".to_string()
+                    ))?;
+
+                    let ret = unsafe {
+                        qdp_kernels::convert_state_to_float(
+                            src_ptr as *const CuDoubleComplex,
+                            *slice.device_ptr() as *mut CuComplex,
+                            self.size_elements,
+                            std::ptr::null_mut(),
+                        )
+                    };
+
+                    if ret != 0 {
+                        return Err(MahoutError::KernelLaunch(
+                            format!("Precision conversion kernel failed: {}", ret)
+                        ));
+                    }
+
+                    device.synchronize()
+                        .map_err(|e| MahoutError::Cuda(format!("Failed to sync after precision conversion: {:?}", e)))?;
+
+                    Ok(Self {
+                        buffer: Arc::new(BufferStorage::F32(GpuBufferRaw { slice })),
+                        num_qubits: self.num_qubits,
+                        size_elements: self.size_elements,
+                    })
+                }
+
+                #[cfg(not(target_os = "linux"))]
+                {
+                    Err(MahoutError::Cuda("Precision conversion requires CUDA (Linux)".to_string()))
+                }
+            }
+            _ => Err(MahoutError::NotImplemented(
+                "Requested precision conversion is not supported".to_string()
+            )),
+        }
+    }
 }
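
Illustrative, crate-internal sketch (not part of the patch) of how to_precision is
intended to be called; it assumes an already-allocated GpuStateVector and the module
paths shown in this patch:

```rust
// Illustrative sketch only; crate-internal usage of the new conversion API.
use std::sync::Arc;
use cudarc::driver::CudaDevice;
use crate::error::Result;
use crate::gpu::memory::{GpuStateVector, Precision};

fn demote_to_float32(device: &Arc<CudaDevice>, state: &GpuStateVector) -> Result<GpuStateVector> {
    // Returns a cheap clone (shared Arc buffer) if the vector is already
    // Float32; otherwise allocates a complex64 buffer, launches
    // convert_state_to_float on the GPU, and synchronizes before returning.
    state.to_precision(device, Precision::Float32)
}
```
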
diff --git a/qdp/qdp-core/src/lib.rs b/qdp/qdp-core/src/lib.rs
index 2f8f42092..e14e2bab0 100644
--- a/qdp/qdp-core/src/lib.rs
+++ b/qdp/qdp-core/src/lib.rs
@@ -19,11 +19,11 @@ pub mod gpu;
 pub mod error;
 pub mod preprocessing;
 pub mod io;
-
 #[macro_use]
 mod profiling;
 
 pub use error::{MahoutError, Result};
+pub use gpu::memory::Precision;
 
 use std::sync::Arc;
 
@@ -37,6 +37,7 @@ use crate::gpu::get_encoder;
 /// Provides unified interface for device management, memory allocation, and DLPack.
 pub struct QdpEngine {
     device: Arc<CudaDevice>,
+    precision: Precision,
 }
 
 impl QdpEngine {
@@ -45,10 +46,16 @@ impl QdpEngine {
     /// # Arguments
     /// * `device_id` - CUDA device ID (typically 0)
     pub fn new(device_id: usize) -> Result<Self> {
+        Self::new_with_precision(device_id, Precision::Float32)
+    }
+
+    /// Initialize engine with explicit precision.
+    pub fn new_with_precision(device_id: usize, precision: Precision) -> Result<Self> {
         let device = CudaDevice::new(device_id)
             .map_err(|e| MahoutError::Cuda(format!("Failed to initialize CUDA device {}: {:?}", device_id, e)))?;
         Ok(Self {
-            device  // CudaDevice::new already returns Arc<CudaDevice> in cudarc 0.11
+            device,  // CudaDevice::new already returns Arc<CudaDevice> in cudarc 0.11
+            precision,
         })
     }
 
@@ -76,6 +83,7 @@ impl QdpEngine {
 
         let encoder = get_encoder(encoding_method)?;
         let state_vector = encoder.encode(&self.device, data, num_qubits)?;
+        let state_vector = state_vector.to_precision(&self.device, self.precision)?;
         let dlpack_ptr = {
             crate::profile_scope!("DLPack::Wrap");
             state_vector.to_dlpack()
@@ -121,6 +129,7 @@ impl QdpEngine {
             num_qubits,
         )?;
 
+        let state_vector = state_vector.to_precision(&self.device, self.precision)?;
         let dlpack_ptr = state_vector.to_dlpack();
         Ok(dlpack_ptr)
     }
diff --git a/qdp/qdp-kernels/src/amplitude.cu b/qdp/qdp-kernels/src/amplitude.cu
index 0e679e2b9..ea5fc27f7 100644
--- a/qdp/qdp-kernels/src/amplitude.cu
+++ b/qdp/qdp-kernels/src/amplitude.cu
@@ -454,6 +454,42 @@ int launch_l2_norm_batch(
     return (int)cudaGetLastError();
 }
 
+/// Kernel: convert complex128 state vector to complex64.
+__global__ void convert_state_to_complex64_kernel(
+    const cuDoubleComplex* __restrict__ input_state,
+    cuComplex* __restrict__ output_state,
+    size_t len
+) {
+    const size_t idx = blockIdx.x * blockDim.x + threadIdx.x;
+    if (idx >= len) return;
+
+    const cuDoubleComplex v = input_state[idx];
+    output_state[idx] = make_cuComplex((float)v.x, (float)v.y);
+}
+
+/// Launch conversion kernel from complex128 to complex64.
+int convert_state_to_float(
+    const cuDoubleComplex* input_state_d,
+    cuComplex* output_state_d,
+    size_t len,
+    cudaStream_t stream
+) {
+    if (len == 0) {
+        return cudaErrorInvalidValue;
+    }
+
+    const int blockSize = 256;
+    const int gridSize = (int)((len + blockSize - 1) / blockSize);
+
+    convert_state_to_complex64_kernel<<<gridSize, blockSize, 0, stream>>>(
+        input_state_d,
+        output_state_d,
+        len
+    );
+
+    return (int)cudaGetLastError();
+}
+
 // TODO: Future encoding methods:
 // - launch_angle_encode (angle encoding)
 // - launch_basis_encode (basis encoding)
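
Illustrative sketch (not part of the patch) of the launch geometry used above: 256
threads per block with a ceiling-divided grid, worked through for a 20-qubit state
vector.

```rust
// Illustrative check of the grid/block arithmetic in
// convert_state_to_complex64_kernel's launcher.
fn grid_size(len: usize, block_size: usize) -> usize {
    (len + block_size - 1) / block_size // ceiling division
}

fn main() {
    let len = 1usize << 20;                 // 2^20 = 1,048,576 amplitudes
    assert_eq!(grid_size(len, 256), 4_096); // 4,096 blocks of 256 threads
}
```
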
diff --git a/qdp/qdp-kernels/src/lib.rs b/qdp/qdp-kernels/src/lib.rs
index f2ce2ad29..bae8782ef 100644
--- a/qdp/qdp-kernels/src/lib.rs
+++ b/qdp/qdp-kernels/src/lib.rs
@@ -36,6 +36,22 @@ unsafe impl cudarc::driver::DeviceRepr for CuDoubleComplex {}
 #[cfg(target_os = "linux")]
 unsafe impl cudarc::driver::ValidAsZeroBits for CuDoubleComplex {}
 
+// Complex number (matches CUDA's cuComplex / cuFloatComplex)
+#[repr(C)]
+#[derive(Debug, Clone, Copy)]
+pub struct CuComplex {
+    pub x: f32,  // Real part
+    pub y: f32,  // Imaginary part
+}
+
+// Implement DeviceRepr for cudarc compatibility
+#[cfg(target_os = "linux")]
+unsafe impl cudarc::driver::DeviceRepr for CuComplex {}
+
+// Also implement ValidAsZeroBits for alloc_zeros support
+#[cfg(target_os = "linux")]
+unsafe impl cudarc::driver::ValidAsZeroBits for CuComplex {}
+
 // CUDA kernel FFI (Linux only, dummy on other platforms)
 #[cfg(target_os = "linux")]
 unsafe extern "C" {
@@ -93,6 +109,18 @@ unsafe extern "C" {
         stream: *mut c_void,
     ) -> i32;
 
+    /// Convert a complex128 state vector to complex64 on GPU.
+    /// Returns CUDA error code (0 = success).
+    ///
+    /// # Safety
+    /// Pointers must reference valid device memory on the provided stream.
+    pub fn convert_state_to_float(
+        input_state_d: *const CuDoubleComplex,
+        output_state_d: *mut CuComplex,
+        len: usize,
+        stream: *mut c_void,
+    ) -> i32;
+
     // TODO: launch_angle_encode, launch_basis_encode
 }
 
@@ -132,3 +160,14 @@ pub extern "C" fn launch_l2_norm_batch(
 ) -> i32 {
     999
 }
+
+#[cfg(not(target_os = "linux"))]
+#[unsafe(no_mangle)]
+pub extern "C" fn convert_state_to_float(
+    _input_state_d: *const CuDoubleComplex,
+    _output_state_d: *mut CuComplex,
+    _len: usize,
+    _stream: *mut c_void,
+) -> i32 {
+    999
+}
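
Illustrative sketch (not part of the patch) of driving the new FFI entry point the way
qdp-core does in gpu/memory.rs: a null stream selects the default CUDA stream, and the
returned code is checked for success (the non-Linux stub always returns 999).

```rust
// Illustrative sketch only; mirrors the call pattern used in qdp-core.
use qdp_kernels::{convert_state_to_float, CuComplex, CuDoubleComplex};

/// # Safety
/// `src` and `dst` must point to valid device allocations of at least `len`
/// elements and must stay alive until the conversion has completed.
unsafe fn convert_on_default_stream(
    src: *const CuDoubleComplex,
    dst: *mut CuComplex,
    len: usize,
) -> Result<(), i32> {
    // Null stream = default CUDA stream; non-zero return is a CUDA error code.
    let ret = unsafe { convert_state_to_float(src, dst, len, std::ptr::null_mut()) };
    if ret == 0 { Ok(()) } else { Err(ret) }
}
```
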
diff --git a/qdp/qdp-python/README.md b/qdp/qdp-python/README.md
index aed316952..98d2b0106 100644
--- a/qdp/qdp-python/README.md
+++ b/qdp/qdp-python/README.md
@@ -7,9 +7,12 @@ PyO3 Python bindings for Apache Mahout QDP.
 ```python
 from mahout_qdp import QdpEngine
 
-# Initialize on GPU 0
+# Initialize on GPU 0 (defaults to float32 output)
 engine = QdpEngine(0)
 
+# Optional: request float64 output if you need higher precision
+# engine = QdpEngine(0, precision="float64")
+
 # Encode data
 data = [0.5, 0.5, 0.5, 0.5]
 dlpack_ptr = engine.encode(data, num_qubits=2, encoding_method="amplitude")
diff --git a/qdp/qdp-python/src/lib.rs b/qdp/qdp-python/src/lib.rs
index 340aae814..04f3c5367 100644
--- a/qdp/qdp-python/src/lib.rs
+++ b/qdp/qdp-python/src/lib.rs
@@ -17,7 +17,7 @@
 use pyo3::prelude::*;
 use pyo3::exceptions::PyRuntimeError;
 use pyo3::ffi;
-use qdp_core::QdpEngine as CoreEngine;
+use qdp_core::{Precision, QdpEngine as CoreEngine};
 use qdp_core::dlpack::DLManagedTensor;
 
 /// Quantum tensor wrapper implementing DLPack protocol
@@ -139,6 +139,7 @@ impl QdpEngine {
     ///
     /// Args:
     ///     device_id: CUDA device ID (typically 0)
+    ///     precision: Output precision ("float32" default, or "float64")
     ///
     /// Returns:
     ///     QdpEngine instance
@@ -146,9 +147,20 @@ impl QdpEngine {
     /// Raises:
     ///     RuntimeError: If CUDA device initialization fails
     #[new]
-    #[pyo3(signature = (device_id=0))]
-    fn new(device_id: usize) -> PyResult<Self> {
-        let engine = CoreEngine::new(device_id)
+    #[pyo3(signature = (device_id=0, precision="float32"))]
+    fn new(device_id: usize, precision: &str) -> PyResult<Self> {
+        let precision = match precision.to_ascii_lowercase().as_str() {
+            "float32" | "f32" | "float" => Precision::Float32,
+            "float64" | "f64" | "double" => Precision::Float64,
+            other => {
+                return Err(PyRuntimeError::new_err(format!(
+                    "Unsupported precision '{}'. Use 'float32' (default) or 'float64'.",
+                    other
+                )))
+            }
+        };
+
+        let engine = CoreEngine::new_with_precision(device_id, precision)
             .map_err(|e| PyRuntimeError::new_err(format!("Failed to initialize: {}", e)))?;
         Ok(Self { engine })
     }
diff --git a/qdp/qdp-python/tests/test_bindings.py b/qdp/qdp-python/tests/test_bindings.py
index 0f7866299..1fc586f78 100644
--- a/qdp/qdp-python/tests/test_bindings.py
+++ b/qdp/qdp-python/tests/test_bindings.py
@@ -85,7 +85,22 @@ def test_pytorch_integration():
     torch_tensor = torch.from_dlpack(qtensor)
     assert torch_tensor.is_cuda
     assert torch_tensor.device.index == 0
-    assert torch_tensor.dtype == torch.complex128
+    assert torch_tensor.dtype == torch.complex64
 
     # Verify shape (2 qubits = 2^2 = 4 elements)
     assert torch_tensor.shape == (4,)
+
+
+@pytest.mark.gpu
+def test_pytorch_precision_float64():
+    """Verify optional float64 precision produces complex128 tensors."""
+    pytest.importorskip("torch")
+    import torch
+    from mahout_qdp import QdpEngine
+
+    engine = QdpEngine(0, precision="float64")
+    data = [1.0, 2.0, 3.0, 4.0]
+    qtensor = engine.encode(data, 2, "amplitude")
+
+    torch_tensor = torch.from_dlpack(qtensor)
+    assert torch_tensor.dtype == torch.complex128
diff --git a/qdp/qdp-python/tests/test_high_fidelity.py b/qdp/qdp-python/tests/test_high_fidelity.py
index 05bc41987..24f11c513 100644
--- a/qdp/qdp-python/tests/test_high_fidelity.py
+++ b/qdp/qdp-python/tests/test_high_fidelity.py
@@ -58,6 +58,15 @@ def engine():
         pytest.skip(f"CUDA initialization failed: {e}")
 
 
+@pytest.fixture(scope="module")
+def engine_float64():
+    """High-precision engine for fidelity-sensitive tests."""
+    try:
+        return QdpEngine(0, precision="float64")
+    except RuntimeError as e:
+        pytest.skip(f"CUDA initialization failed: {e}")
+
+
 # 1. Core Logic and Boundary Tests
 
 
@@ -73,7 +82,9 @@ def engine():
         (20, 1_000_000, "Large - Async Pipeline"),
     ],
 )
-def test_amplitude_encoding_fidelity_comprehensive(engine, num_qubits, data_size, desc):
+def test_amplitude_encoding_fidelity_comprehensive(
+    engine_float64, num_qubits, data_size, desc
+):
     """Test fidelity across sync path, async pipeline, and chunk boundaries."""
     print(f"\n[Test Case] {desc} (Size: {data_size})")
 
@@ -87,7 +98,7 @@ def test_amplitude_encoding_fidelity_comprehensive(engine, num_qubits, data_size
         expected_state = np.concatenate([expected_state, padding])
 
     expected_state_complex = expected_state.astype(np.complex128)
-    qtensor = engine.encode(raw_data.tolist(), num_qubits, "amplitude")
+    qtensor = engine_float64.encode(raw_data.tolist(), num_qubits, "amplitude")
     torch_state = torch.from_dlpack(qtensor)
 
     assert torch_state.is_cuda, "Tensor must be on GPU"
@@ -110,6 +121,7 @@ def test_complex_integrity(engine):
     qtensor = engine.encode(raw_data.tolist(), num_qubits, "amplitude")
     torch_state = torch.from_dlpack(qtensor)
 
+    assert torch_state.dtype == torch.complex64
     imag_error = torch.sum(torch.abs(torch_state.imag)).item()
     print(f"\nSum of imaginary parts (should be near 0): {imag_error}")
 
@@ -123,12 +135,12 @@ def test_complex_integrity(engine):
 
 
 @pytest.mark.gpu
-def test_numerical_stability_underflow(engine):
+def test_numerical_stability_underflow(engine_float64):
     """Test precision with extremely small values (1e-150)."""
     num_qubits = 4
     data = [1e-150] * 16
 
-    qtensor = engine.encode(data, num_qubits, "amplitude")
+    qtensor = engine_float64.encode(data, num_qubits, "amplitude")
     torch_state = torch.from_dlpack(qtensor)
 
     assert not torch.isnan(torch_state).any(), "Result contains NaN for small inputs"
