viiccwen opened a new pull request, #1025:
URL: https://github.com/apache/mahout/pull/1025
### Purpose of PR
This PR wires the existing core `encode_from_gpu_ptr_f32_with_stream` API
into the Python `_qdp` bindings, enabling amplitude encoding directly from 1D
float32 CUDA tensors.
### Changes
- **Validation:** `validate_cuda_tensor_for_encoding` now accepts both
`float64` and `float32` CUDA tensors for amplitude encoding; angle and basis
requirements are unchanged.
- **Dispatch:** In `QdpEngine.encode`, when the input is a CUDA tensor with
  `encoding_method="amplitude"`:
  - **1D float32:** Calls `encode_from_gpu_ptr_f32_with_stream` using the
    tensor’s `data_ptr()`, `numel()`, and the current CUDA stream.
  - **2D float32:** Raises a clear `RuntimeError` stating that batch float32
    amplitude encoding is not yet supported, and suggests float64 or
    per-sample encoding instead.
- **Refactor:** CUDA tensor handling is moved into an internal helper
`_encode_from_cuda_tensor` (placed after `_encode_stream_internal`) to keep
`encode` readable.
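The validation and dispatch rules above can be summarised as a small decision function. This is a hypothetical, pure-Python sketch: `choose_cuda_path`, the string dtype names, the returned labels, and the `ValueError` for unsupported dtypes or ranks are illustrative assumptions, not the actual code in `_encode_from_cuda_tensor`. It only mirrors the behaviour this PR describes and does not call `torch` or `_qdp`.

```python
# Hypothetical model of the CUDA-tensor dispatch described in this PR.
# Dtypes are plain strings; the returned labels only name the path that
# the real `_encode_from_cuda_tensor` helper would take.

AMPLITUDE_DTYPES = {"float32", "float64"}  # accepted for amplitude after this PR


def choose_cuda_path(dtype: str, ndim: int, encoding_method: str) -> str:
    """Return a label for the path a CUDA tensor would take (sketch only)."""
    if encoding_method != "amplitude":
        # Angle and basis requirements are unchanged by this PR.
        return "unchanged angle/basis handling"
    if dtype not in AMPLITUDE_DTYPES:
        # Assumed error type; the real validation error may differ.
        raise ValueError(f"amplitude encoding expects float32 or float64, got {dtype}")
    if dtype == "float64":
        return "existing float64 path"
    # From here on the tensor is float32.
    if ndim == 1:
        return "encode_from_gpu_ptr_f32_with_stream"
    if ndim == 2:
        raise RuntimeError(
            "batch float32 amplitude encoding is not yet supported; "
            "use float64 or encode each sample separately"
        )
    raise ValueError(f"expected a 1D or 2D tensor, got ndim={ndim}")
```

For instance, `choose_cuda_path("float32", 1, "amplitude")` selects the new pointer-based path, while `choose_cuda_path("float32", 2, "amplitude")` raises `RuntimeError`, matching the batch limitation noted above.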
### Example
```py
import torch
from _qdp import QdpEngine
engine = QdpEngine(0)
# This path was already supported (host list -> core encode)
data_list = [1.0, 2.0, 3.0, 4.0]
qtensor_list = engine.encode(data_list, 2, "amplitude")
print("From list:", qtensor_list)
# This path is what this PR adds: 1D float32 CUDA tensor ->
# encode_from_gpu_ptr_f32_with_stream
data_cuda_f32 = torch.tensor(
    [1.0, 2.0, 3.0, 4.0], dtype=torch.float32, device="cuda:0"
)
qtensor_f32 = engine.encode(data_cuda_f32, 2, "amplitude")
print("From 1D float32 CUDA:", qtensor_f32)
# Verify DLPack round-trip
t = torch.from_dlpack(qtensor_f32)
print(" shape:", t.shape, "dtype:", t.dtype, "device:", t.device)
```
```
From list: <builtins.QuantumTensor object at 0x7f51b6597d20>
From 1D float32 CUDA: <builtins.QuantumTensor object at 0x7f51b6597cf0>
shape: torch.Size([1, 4]) dtype: torch.complex64 device: cuda:0
```
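The printed reprs only confirm that a `QuantumTensor` was produced. To sanity-check the values, a CPU reference can help. The sketch below is hypothetical (`amplitude_encode_reference` is not part of `_qdp`) and assumes that amplitude encoding L2-normalises the input and zero-pads it to `2**num_qubits` amplitudes, which is consistent with the `[1, 4]` shape reported for 4 inputs on 2 qubits.

```python
# Hypothetical CPU reference for amplitude encoding (not the QDP kernel).
# Assumption: inputs are L2-normalised and zero-padded to 2**num_qubits.
import math


def amplitude_encode_reference(values, num_qubits):
    """Return the expected state vector as a list of Python complex numbers."""
    dim = 2 ** num_qubits
    if len(values) > dim:
        raise ValueError(
            f"{len(values)} values do not fit in {num_qubits} qubits ({dim} amplitudes)"
        )
    norm = math.sqrt(sum(v * v for v in values))
    if norm == 0.0:
        raise ValueError("cannot normalise a zero vector")
    padded = list(values) + [0.0] * (dim - len(values))
    # Real-valued inputs give purely real amplitudes.
    return [complex(v / norm, 0.0) for v in padded]


state = amplitude_encode_reference([1.0, 2.0, 3.0, 4.0], 2)
print([round(a.real, 4) for a in state])  # [0.1826, 0.3651, 0.5477, 0.7303]
```

To compare against the DLPack tensor `t` from the example, one option is `torch.allclose(t[0].cpu(), torch.tensor(state, dtype=torch.complex64), atol=1e-6)`. The `[0]` drops the leading batch dimension, and a loose tolerance is appropriate because the float32 path produces `complex64` output.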
### Related Issues or PRs
closes #1024
### Changes Made
- [ ] Bug fix
- [x] New feature
- [ ] Refactoring
- [ ] Documentation
- [x] Test
- [ ] CI/CD pipeline
- [ ] Other
### Breaking Changes
- [x] Yes
- [ ] No
### Checklist
- [x] Added or updated unit tests for all changes
- [ ] Added or updated documentation for all changes
- [x] Successfully built and ran all unit tests or manual tests locally
- [x] PR title follows "MAHOUT-XXX: Brief Description" format (if related to
an issue)
- [x] Code follows ASF guidelines
--
This is an automated message from the Apache Git Service.