viiccwen opened a new pull request, #1025:
URL: https://github.com/apache/mahout/pull/1025

   ### Purpose of PR
   This PR wires the existing core `encode_from_gpu_ptr_f32_with_stream` API 
into the Python `_qdp` bindings, enabling amplitude encoding directly from 1D 
float32 CUDA tensors.
   
   ### Changes
   - **Validation:** `validate_cuda_tensor_for_encoding` now accepts both 
`float64` and `float32` CUDA tensors for amplitude encoding; angle and basis 
requirements are unchanged.
   - **Dispatch:** In `QdpEngine.encode`, when the input is a CUDA tensor with 
`encoding_method="amplitude"`:
     - **1D float32:** Calls `encode_from_gpu_ptr_f32_with_stream` using the 
tensor’s `data_ptr()`, `numel()`, and the current CUDA stream.
     - **2D float32:** Raises a clear `RuntimeError` stating that batch float32 amplitude encoding is not yet supported, and suggests float64 or per-sample encoding instead.
   - **Refactor:** CUDA tensor handling is moved into an internal helper 
`_encode_from_cuda_tensor` (placed after `_encode_stream_internal`) to keep 
`encode` readable.
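The validation and dispatch rules above can be summarized as a small pure-Python sketch. It is hypothetical: `FakeCudaTensor` and `route_cuda_amplitude` are illustrative names that do not exist in `_qdp`, dtypes are modeled as strings rather than `torch.dtype`, the float64 path name is a placeholder, and the function only reports which path would be taken instead of calling the core API.

```python
from dataclasses import dataclass


@dataclass
class FakeCudaTensor:
    """Hypothetical stand-in for a torch CUDA tensor (only the fields routing needs)."""
    dtype: str          # e.g. "float32" or "float64"
    ndim: int           # 1 = single sample, 2 = batch
    is_cuda: bool = True


def route_cuda_amplitude(tensor: FakeCudaTensor, encoding_method: str) -> str:
    """Return the name of the path a CUDA tensor would take (sketch only).

    This mirrors the PR description; it is not the real
    `_encode_from_cuda_tensor` implementation.
    """
    if not tensor.is_cuda:
        raise ValueError("expected a CUDA tensor")
    if encoding_method != "amplitude":
        # Angle and basis requirements are unchanged by this PR (not modeled here).
        return "unchanged angle/basis path"
    # Validation: amplitude encoding now accepts float64 and float32.
    if tensor.dtype not in ("float64", "float32"):
        raise ValueError(f"amplitude encoding does not support {tensor.dtype}")
    if tensor.dtype == "float64":
        # Placeholder name for the pre-existing float64 path.
        return "existing float64 path"
    if tensor.ndim == 1:
        # 1D float32: the binding passes data_ptr(), numel() and the current stream.
        return "encode_from_gpu_ptr_f32_with_stream"
    if tensor.ndim == 2:
        # 2D float32: batch encoding is not implemented yet.
        raise RuntimeError(
            "batch float32 amplitude encoding is not yet supported; "
            "use float64 or encode each sample separately"
        )
    # Assumption: other ranks are rejected (the PR only describes 1D and 2D).
    raise ValueError(f"unsupported tensor rank {tensor.ndim}")


print(route_cuda_amplitude(FakeCudaTensor("float32", 1), "amplitude"))
```

For a 2D float32 batch, the error message points to two workarounds: cast the batch to float64, or loop over its rows and encode each 1D row on its own.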
   
   ### Example 
   
   ```py
   import torch
   from _qdp import QdpEngine
   
   engine = QdpEngine(0)
   
   # This path was already supported (host list -> core encode)
   data_list = [1.0, 2.0, 3.0, 4.0]
   qtensor_list = engine.encode(data_list, 2, "amplitude")
   print("From list:", qtensor_list)
   
   # New in this PR: 1D float32 CUDA tensor -> encode_from_gpu_ptr_f32_with_stream
   data_cuda_f32 = torch.tensor([1.0, 2.0, 3.0, 4.0], dtype=torch.float32, device="cuda:0")
   qtensor_f32 = engine.encode(data_cuda_f32, 2, "amplitude")
   print("From 1D float32 CUDA:", qtensor_f32)
   
   # Verify DLPack round-trip
   t = torch.from_dlpack(qtensor_f32)
   print("  shape:", t.shape, "dtype:", t.dtype, "device:", t.device)
   ```
   
   ```
   From list: <builtins.QuantumTensor object at 0x7f51b6597d20>
   From 1D float32 CUDA: <builtins.QuantumTensor object at 0x7f51b6597cf0>
     shape: torch.Size([1, 4]) dtype: torch.complex64 device: cuda:0
   ```
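To see what the four complex64 values in that `[1, 4]` tensor should contain, here is a CPU-only sketch of amplitude encoding for the same input. It assumes the standard scheme (zero-pad to `2**num_qubits` entries, then divide by the L2 norm); `amplitude_encode` is a hypothetical helper, not part of `_qdp`, and the padding rule is an assumption about the engine's behavior. The leading dimension of 1 in the printed shape appears to be a batch axis.

```python
import math


def amplitude_encode(data, num_qubits):
    """Hypothetical reference: zero-pad to 2**num_qubits and L2-normalize."""
    dim = 2 ** num_qubits
    if len(data) > dim:
        raise ValueError(f"{len(data)} values do not fit in {num_qubits} qubits")
    # Assumption: shorter inputs are zero-padded up to the state dimension.
    padded = list(data) + [0.0] * (dim - len(data))
    norm = math.sqrt(sum(x * x for x in padded))
    if norm == 0.0:
        raise ValueError("cannot amplitude-encode an all-zero vector")
    # Amplitudes are real here; they are returned as complex numbers to match
    # the complex-valued state tensor shown in the output above.
    return [complex(x / norm) for x in padded]


state = amplitude_encode([1.0, 2.0, 3.0, 4.0], num_qubits=2)
print([round(a.real, 4) for a in state])  # [0.1826, 0.3651, 0.5477, 0.7303]
```

Each amplitude is the input divided by sqrt(30), so the squared magnitudes sum to 1. Comparing `torch.from_dlpack(qtensor_f32)` against these values (within float32 tolerance) is one way to sanity-check the new path.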
   
   ### Related Issues or PRs
   closes #1024
   
   ### Changes Made
   <!-- Please mark one with an "x"   -->
   - [ ] Bug fix
   - [x] New feature
   - [ ] Refactoring
   - [ ] Documentation
   - [x] Test
   - [ ] CI/CD pipeline
   - [ ] Other
   
   ### Breaking Changes
   <!-- Does this PR introduce a breaking change? -->
   - [x] Yes
   - [ ] No
   
   ### Checklist
   <!-- Please mark each item with an "x" when complete -->
   <!-- If not all items are complete, please open this as a **Draft PR**.
   Once all requirements are met, mark as ready for review. -->
   
   - [x] Added or updated unit tests for all changes
   - [ ] Added or updated documentation for all changes
   - [x] Successfully built and ran all unit tests or manual tests locally
   - [x] PR title follows "MAHOUT-XXX: Brief Description" format (if related to an issue)
   - [x] Code follows ASF guidelines
   

