This is an automated email from the ASF dual-hosted git repository.
ryankert01 pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/mahout.git
The following commit(s) were added to refs/heads/main by this push:
new 4f21c1076 [QDP] Close AMD-vs-CUDA encoder parity: add iqp, iqp-z, phase (#1292)
4f21c1076 is described below
commit 4f21c1076f435d381b8a664f1f3055b60fc92e3d
Author: Ryan Huang <[email protected]>
AuthorDate: Mon May 11 01:04:54 2026 +0800
[QDP] Close AMD-vs-CUDA encoder parity: add iqp, iqp-z, phase (#1292)
* [QDP] Close AMD-vs-CUDA encoder parity gap: add iqp, iqp-z, phase
CUDA QdpEngine accepts amplitude, angle, basis, iqp, iqp-z, and phase.
The Triton AMD path only implemented the first three, so AMD users hit
a hard error on the IQP- and phase-family encodings (e.g. SVHN-IQP).
This adds vectorized PyTorch implementations for the missing methods on
TritonAmdEngine, dispatched through the same ``encode(method=...)``
contract:
- ``iqp``: full ZZ entanglement, phase = Σ x_i·data_i + Σ_{i<j} x_i x_j·data_ij
(x_i is bit i of the basis index; ``data`` holds the n Z terms followed by the
n(n-1)/2 ZZ terms), followed by an n-stage Walsh-Hadamard butterfly and 1/2^n
scaling.
- ``iqp-z``: Z-only diagonal, same FWT path with no ZZ pairs.
- ``phase``: per-qubit product state (1/√2^n)·exp(i·Σ_k phases_k·b_k), where
b_k is bit k of the basis index.
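A minimal pure-Python sketch of the IQP formulas above, written for readability rather than speed: it builds θ(x) for every basis index, applies the unnormalized n-stage Walsh-Hadamard butterfly, scales by 1/2^n, and can be cross-checked against the O(4^n) definition amp[z] = 2^-n · Σ_x (-1)^popcount(x & z) · exp(i·θ(x)). The helper names (`iqp_phase`, `walsh_hadamard`, `iqp_encode`, `iqp_direct`) are hypothetical and are not the `qumat_qdp` API; the parameter layout (n Z terms, then the n(n-1)/2 ZZ terms in `itertools.combinations` order) is assumed to match the `torch.combinations` pair order used in the diff.

```python
import cmath
import itertools


def iqp_phase(params, n, enable_zz=True):
    """theta(x) = sum_i x_i*data_i (+ sum_{i<j} x_i*x_j*data_ij) for every basis index x."""
    pairs = list(itertools.combinations(range(n), 2)) if enable_zz else []
    thetas = []
    for x in range(1 << n):
        bits = [(x >> k) & 1 for k in range(n)]  # x_k = bit k of x (little-endian)
        theta = sum(params[k] * bits[k] for k in range(n))
        for p, (i, j) in enumerate(pairs):
            theta += params[n + p] * bits[i] * bits[j]
        thetas.append(theta)
    return thetas


def walsh_hadamard(vec):
    """Unnormalized n-stage butterfly: at each stride, (a, b) <- (a + b, a - b)."""
    out = list(vec)
    size, stride = len(out), 1
    while stride < size:
        for start in range(0, size, 2 * stride):
            for t in range(start, start + stride):
                a, b = out[t], out[t + stride]
                out[t], out[t + stride] = a + b, a - b
        stride *= 2
    return out


def iqp_encode(params, n, enable_zz=True):
    """IQP state vector: WHT(exp(i*theta)) scaled by 1/2^n."""
    expected = n + n * (n - 1) // 2 if enable_zz else n
    if len(params) != expected:
        raise ValueError(f"expects {expected} parameters for {n} qubits, got {len(params)}")
    f = [cmath.exp(1j * t) for t in iqp_phase(params, n, enable_zz)]
    scale = 1.0 / (1 << n)
    return [v * scale for v in walsh_hadamard(f)]


def iqp_direct(params, n, enable_zz=True):
    """O(4^n) reference: amp[z] = 2^-n * sum_x (-1)^popcount(x & z) * exp(i*theta(x))."""
    thetas = iqp_phase(params, n, enable_zz)
    size = 1 << n
    return [
        sum((-1) ** bin(x & z).count("1") * cmath.exp(1j * thetas[x]) for x in range(size)) / size
        for z in range(size)
    ]


state = iqp_encode([0.3, -1.2, 0.7, 0.5, 0.1, -0.4], 3)  # 3 Z terms + 3 ZZ terms
print(sum(abs(a) ** 2 for a in state))  # ~1.0: the output is a unit-norm state
```

The butterfly turns the O(4^n) double sum into O(n·2^n) additions, which is why both the diff and this sketch prefer it over forming the Hadamard matrix explicitly.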
Parity tests added against ``qumat_qdp.torch_ref.iqp_encode`` (which is
already validated against CUDA upstream) and a local pure-torch phase
reference. Also added unit-norm structural checks, param-count
validation, float64 precision contract, and a router test that the
public ``QdpEngine(backend="amd")`` accepts the new methods.
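The `phase` structural checks rest on the product-state form: applying H and then P(φ_k) to each qubit gives (|0⟩ + e^{iφ_k}|1⟩)/√2, and the tensor product of those per-qubit states equals the closed form (1/√2^n)·exp(i·Σ_k phases_k·b_k), so every amplitude has magnitude 1/√2^n and the state has unit norm. The pure-Python sketch below checks that equivalence; the function names are hypothetical, and it assumes little-endian ordering (qubit k is bit k of the basis index, as in `(b >> k) & 1`).

```python
import cmath
import math
from functools import reduce


def phase_closed_form(phases):
    """amp[b] = (1/sqrt(2^n)) * exp(i * sum_k phases[k] * ((b >> k) & 1))."""
    n = len(phases)
    norm = math.sqrt(0.5) ** n
    return [
        norm * cmath.exp(1j * sum(phases[k] * ((b >> k) & 1) for k in range(n)))
        for b in range(1 << n)
    ]


def kron(a, b):
    """Kronecker product of two state vectors; `a` supplies the more significant bits."""
    return [x * y for x in a for y in b]


def phase_product_state(phases):
    """Tensor product of per-qubit states (|0> + e^{i*phi}|1>) / sqrt(2)."""
    qubits = [[math.sqrt(0.5), math.sqrt(0.5) * cmath.exp(1j * p)] for p in phases]
    # Qubit 0 must be the least-significant bit, so fold from the last qubit down.
    return reduce(kron, reversed(qubits))


closed = phase_closed_form([0.4, -1.1, 2.3])
product = phase_product_state([0.4, -1.1, 2.3])
print(max(abs(x - y) for x, y in zip(closed, product)))  # floating-point noise only
```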
Verified on AMD Instinct MI300X (ROCm 7.2 / torch 2.9.0+rocm6.4 /
triton 3.5.0): the full triton_amd test file reports 18 passed, 2 skipped
(the skips are NVIDIA CUDA-only references).
* Optimize TritonAmdEngine encoders + Triton @jit phase kernel
Addresses Copilot review on PR #1292 and applies kernel-level
optimizations across all six AMD encoders.
PR review responses:
- Drop the unreachable `test_triton_amd_iqp_cuda_reference_optional`
(decorator required `torch.version.cuda` while body required
`is_triton_amd_available()` → mutually exclusive). Replace with a
meaningful float64 IQP precision contract test that actually runs.
- Qualify README about the CUDA-tensor `phase` limitation: the Python
extension's CUDA-tensor allowlist (`CUDA_ENCODING_METHODS`) does not
yet include `phase`, so cuda-resident torch tensors must `.cpu()`
first. Tracked as a follow-up.
- The pair-matrix-rewrite suggestion (per-pair Python loop) is
rejected: O(n²) small kernel launches are slower than a single matmul on
modern GPUs, and the current path matches `torch_ref.iqp_encode` and the
CUDA FWT phase kernel. Add a `_IQP_PAIR_MATRIX_MAX_N` guard that
*does* fall back to a pair loop past n=20 (where the (2^n × n_pairs)
workspace dominates HBM), so the OOM scenario is bounded.
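The trade-off behind the guard can be seen at toy scale: the pair-matrix path and the per-pair loop compute the same ZZ phase θ_ZZ(b) = Σ_{i<j} b_i·b_j·data_ij and differ only in the workspace they materialize. This pure-Python sketch uses hypothetical helper names, with plain lists standing in for the torch `matmul` and broadcasting in `_iqp_phase`; it checks the two paths against each other and reproduces the ~760 MiB float32 pair-matrix estimate at n=20.

```python
import itertools


def bits_table(n):
    """bits[b][k] = (b >> k) & 1, shape (2^n, n)."""
    return [[(b >> k) & 1 for k in range(n)] for b in range(1 << n)]


def zz_phase_pair_matrix(zz, n):
    """Materialize the (2^n, n_pairs) pair matrix, then take one dot product per row."""
    pairs = list(itertools.combinations(range(n), 2))
    pair_matrix = [[row[i] * row[j] for i, j in pairs] for row in bits_table(n)]
    return [sum(w * m for w, m in zip(zz, row)) for row in pair_matrix]


def zz_phase_loop(zz, n):
    """Accumulate one pair at a time; no (2^n, n_pairs) workspace is built."""
    bits = bits_table(n)
    phase = [0.0] * (1 << n)
    p = 0
    for i in range(n - 1):
        for j in range(i + 1, n):
            for b, row in enumerate(bits):
                phase[b] += zz[p] * row[i] * row[j]
            p += 1
    return phase


def pair_matrix_bytes(n, itemsize=4):
    """Pair-matrix size: 2^n rows x n(n-1)/2 pairs x itemsize bytes (4 for float32)."""
    return (1 << n) * (n * (n - 1) // 2) * itemsize


print(pair_matrix_bytes(20) / 2**20)  # 760.0 (MiB) at the _IQP_PAIR_MATRIX_MAX_N cutoff
```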
Encoder optimizations (verified on MI300X vs `qumat_qdp.torch_ref`,
batch=64, fp32 input; q = number of qubits, values are speedups):
| Encoder | q=8 | q=12 | q=16 |
|--------|-------|-------|-------|
| amplitude | 0.95× | 0.95× | 1.00× |
| angle | 1.57× | 1.37× | 1.04× |
| basis | 2.18× | 2.10× | 2.14× |
| iqp(ZZ) | 1.96× | 1.81× | 1.14× |
| iqp-z | 1.35× | 1.32× | 0.91× |
| **phase** | **5.29×** | **5.39×** | **5.30×** |
What changed:
- **Real `@triton.jit` phase kernel** (fp32 / n ≤ 32). One HIP kernel
fuses bit-pattern materialization + θ(b) accumulation + cos/sin +
1/√2^n scaling + complex-pack, writing the output buffer interleaved
via `view_as_real`. The PyTorch fallback path (used at fp64 or n > 32)
was making 5 intermediate (B, S) allocations; the kernel makes one.
- **Per-engine bits-table cache** (`_bits_cache`): the
`((idx >> arange(n)) & 1).to(real)` table was being rebuilt on every
call by `angle`/`iqp`/`phase`. Now cached per (n, dtype). At n=16 the
table has 2^16 × 16 ≈ 1M entries, so each call saves a ~8 MiB int64
intermediate plus a ~4 MiB float32 (8 MiB float64) table.
- **Pair-index cache** (`_pair_cache`): `torch.combinations(arange(n))`
cached per n.
- **`encode_amplitude`**: replaced
`torch.complex(amp, zeros_like(amp)).to(complex_dtype)` with a single
`amp.to(complex_dtype)` (writes (real, 0) interleaved in one kernel
vs. zeros_like + complex_pack + cast = three).
- **`encode_angle`**: collapsed the n-step Python product loop (which
reallocated a (B, S) tensor per qubit) into a single
`where(bits, sin, cos).prod(dim=2)` reduction.
- **`encode_iqp`**: in-place n-stage Walsh-Hadamard butterfly using a
single `(B, S/2)` scratch buffer. The previous
`cat([lo+hi, lo-hi], dim=2)` allocated a fresh (B, S) tensor every
stage; now `sub(out=scratch); a.add_(b); b.copy_(scratch)` reuses one
workspace across all n stages. Also packs `f` via `torch.complex(cos,
sin)` in one shot rather than writing to strided `.real`/`.imag`.
- **`_to_2d` fast path**: skip `as_tensor` + `.contiguous()` work when
the caller already supplies a 2-D, contiguous, on-device,
correctly-typed torch tensor (the common case for benchmarks).
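Two of the rewrites above can be sanity-checked without a GPU. The pure-Python sketch below (hypothetical names; lists stand in for tensors) compares (1) a Walsh-Hadamard butterfly that builds a fresh buffer every stage against one that reuses a single size/2 scratch buffer, mirroring `sub(out=scratch); a.add_(b); b.copy_(scratch)` pair by pair rather than as bulk tensor ops, and (2) the per-qubit angle product loop against the closed form amp[b] = Π_k (sin θ_k if bit k of b is set else cos θ_k) that `where(bits, sin, cos).prod(dim=2)` computes.

```python
import math


def wht_fresh_buffers(vec):
    """Each stage allocates a new output list (like the old cat-based version)."""
    out = list(vec)
    size, stride = len(out), 1
    while stride < size:
        new = [0.0] * size
        for start in range(0, size, 2 * stride):
            for t in range(start, start + stride):
                new[t] = out[t] + out[t + stride]
                new[t + stride] = out[t] - out[t + stride]
        out = new
        stride *= 2
    return out


def wht_with_scratch(vec):
    """One size/2 scratch buffer reused for every stage; the output is updated in place."""
    out = list(vec)
    size, stride = len(out), 1
    scratch = [0.0] * (size // 2)
    while stride < size:
        p = 0
        for start in range(0, size, 2 * stride):
            for t in range(start, start + stride):
                scratch[p] = out[t] - out[t + stride]  # scratch <- a - b
                out[t] += out[t + stride]  # a <- a + b (in place)
                out[t + stride] = scratch[p]  # b <- (a - b) read back from scratch
                p += 1
        stride *= 2
    return out


def angle_loop(angles):
    """Old style: multiply a running amplitude by one per-qubit factor at a time."""
    n = len(angles)
    amp = [1.0] * (1 << n)
    for k, theta in enumerate(angles):
        for b in range(1 << n):
            amp[b] *= math.sin(theta) if (b >> k) & 1 else math.cos(theta)
    return amp


def angle_closed_form(angles):
    """amp[b] = prod_k (sin(theta_k) if bit k of b is set else cos(theta_k))."""
    n = len(angles)
    sin = [math.sin(t) for t in angles]
    cos = [math.cos(t) for t in angles]
    return [math.prod(sin[k] if (b >> k) & 1 else cos[k] for k in range(n)) for b in range(1 << n)]


signal = [0.5, -1.0, 2.0, 0.25, 0.0, 1.5, -0.75, 3.0]
print(wht_fresh_buffers(signal) == wht_with_scratch(signal))  # True
```

Both butterfly variants perform the same floating-point operations in the same order, so they agree exactly; the scratch version only changes how much memory is allocated per stage.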
Test parity: 19 passed, 1 skipped. The skip is
`test_triton_amd_cuda_reference_optional`, the pre-existing amplitude
cross-reference; it has the same latent skipif problem as the iqp
CUDA-reference test removed above, and fixing it is out of scope here.
Numerical parity: all encoders still match `torch_ref` /
`_torch_phase_ref` within float-rounding tolerance; the IQP fp64
contract test confirms `atol=1e-12`.
---
qdp/qdp-python/README.md | 13 +-
qdp/qdp-python/TRITON_AMD_BACKEND.md | 9 +-
qdp/qdp-python/qumat_qdp/triton_amd.py | 339 +++++++++++++++++++++---
qdp/qdp-python/tests/test_triton_amd_backend.py | 175 ++++++++++++
4 files changed, 499 insertions(+), 37 deletions(-)
diff --git a/qdp/qdp-python/README.md b/qdp/qdp-python/README.md
index 8b3d964ac..cacfa0417 100644
--- a/qdp/qdp-python/README.md
+++ b/qdp/qdp-python/README.md
@@ -73,11 +73,18 @@ See `qdp/qdp-python/TRITON_AMD_BACKEND.md` for Triton AMD setup and validation d
| `amplitude` | Normalize input as quantum amplitudes |
| `angle` | Map values to rotation angles (one per qubit) |
| `basis` | Encode integer as computational basis state |
-| `iqp` | IQP-style encoding with entanglement |
+| `iqp` | IQP-style encoding with full ZZ entanglement |
+| `iqp-z` | IQP encoding with Z-only diagonal (no ZZ pairs) |
+| `phase` | Per-qubit phase product state via H⊗P(x_k) |
Backend support boundary:
-- CUDA (`QdpEngine`): `amplitude`, `angle`, `basis`, `iqp`
-- AMD (`QdpEngine(..., backend="amd")`): `amplitude`, `angle`, `basis` (no `iqp` yet)
+- CUDA (`QdpEngine`): `amplitude`, `angle`, `basis`, `iqp`, `iqp-z`, `phase`
+ - `phase` is currently only reachable on the CUDA path via host inputs
+ (Python list / NumPy / file / CPU torch tensor). The Python extension's
+ CUDA-tensor validation does not yet allowlist `phase`; cuda-resident
+ torch tensors must use `.cpu()` first when targeting `phase`. Tracked as
+ a follow-up.
+- AMD (`QdpEngine(..., backend="amd")`): `amplitude`, `angle`, `basis`, `iqp`, `iqp-z`, `phase`
## Input Sources
diff --git a/qdp/qdp-python/TRITON_AMD_BACKEND.md b/qdp/qdp-python/TRITON_AMD_BACKEND.md
index 93e01d466..3b3196187 100644
--- a/qdp/qdp-python/TRITON_AMD_BACKEND.md
+++ b/qdp/qdp-python/TRITON_AMD_BACKEND.md
@@ -65,9 +65,9 @@ Supported methods:
- `amplitude`
- `angle`
- `basis`
-
-Not supported in the AMD route yet:
-- `iqp` (currently CUDA backend only)
+- `iqp` (full, with ZZ entanglement)
+- `iqp-z` (Z-only diagonal, no ZZ pairs)
+- `phase`
## Correctness tests
@@ -79,7 +79,8 @@ uv run --project qdp/qdp-python pytest -m rocm
qdp/qdp-python/tests -q
```
Tests include:
-- parity against Torch reference outputs (amplitude/angle/basis)
+- parity against Torch reference outputs (amplitude/angle/basis/iqp)
+- structural checks for `phase` (output is a unit-norm product state)
- optional parity against CUDA backend reference (when NVIDIA CUDA path is
present)
## Baseline benchmark
diff --git a/qdp/qdp-python/qumat_qdp/triton_amd.py b/qdp/qdp-python/qumat_qdp/triton_amd.py
index 1a531c4e7..8bcbd5d02 100644
--- a/qdp/qdp-python/qumat_qdp/triton_amd.py
+++ b/qdp/qdp-python/qumat_qdp/triton_amd.py
@@ -18,7 +18,8 @@
from __future__ import annotations
-from dataclasses import dataclass
+import math
+from dataclasses import dataclass, field
from importlib import import_module
from typing import Any
@@ -34,6 +35,7 @@ def _load_optional_module(name: str) -> Any | None:
torch_mod = _load_optional_module("torch")
triton_mod = _load_optional_module("triton")
+triton_lang = _load_optional_module("triton.language")
def _is_rocm_runtime() -> bool:
@@ -54,13 +56,85 @@ def is_triton_amd_available() -> bool:
return True
+# ---------------------------------------------------------------------------
+# Triton kernel: fused phase encoder (real-only path).
+#
+# One kernel per program covers BLOCK output basis-states for a single sample,
+# fusing: bit-pattern materialization + θ(b) accumulation + sin/cos + 1/√2^n
+# scaling + complex-pack into the (B, S) real/imag planes. The PyTorch path
+# below allocates 5 intermediates of size O(B · S); this kernel writes the
+# output in a single pass.
+#
+# Real and imag planes are written as separate float buffers, then the caller
+# stitches them via ``torch.complex`` (free metadata view; PyTorch fuses the
+# stride pattern). This avoids needing complex-typed pointers in Triton, which
+# the HIP backend does not support directly.
+#
+# Limitations: float32 + n_qubits ≤ 32 (single int32 bit packing). For n > 32
+# or float64 the engine falls back to the vectorized PyTorch path, which is
+# already memory-bound, not compute-bound.
+# ---------------------------------------------------------------------------
+
+if triton_mod is not None and triton_lang is not None:
+ tl = triton_lang
+
+ @triton_mod.jit
+ def _phase_encode_kernel(
+ phases_ptr, # *fp32, shape (B, n_qubits)
+ out_ptr, # *fp32, view-as-real of complex64 output: (B, 2·state_len)
+ n_qubits,
+ state_len,
+ norm_factor, # 1/√2^n
+ BLOCK: tl.constexpr,
+ ):
+ pid_b = tl.program_id(0)
+ pid_s = tl.program_id(1)
+
+ s_offsets = pid_s * BLOCK + tl.arange(0, BLOCK)
+ s_mask = s_offsets < state_len
+
+ # φ(b) = Σ_k phases[k] · ((b >> k) & 1): fused bit unpack + accumulate.
+ phi = tl.zeros([BLOCK], dtype=tl.float32)
+ for k in range(0, n_qubits):
+ bit_k = ((s_offsets >> k) & 1).to(tl.float32)
+ phase_k = tl.load(phases_ptr + pid_b * n_qubits + k)
+ phi += phase_k * bit_k
+
+ re = tl.cos(phi) * norm_factor
+ im = tl.sin(phi) * norm_factor
+
+ # Write interleaved into the complex64 buffer's float view: each
+ # output element occupies two adjacent floats (re, im). One kernel,
+ # one allocation; no separate planes that would need a final stitch.
+ base = pid_b * state_len * 2 + s_offsets * 2
+ tl.store(out_ptr + base, re, mask=s_mask)
+ tl.store(out_ptr + base + 1, im, mask=s_mask)
+
+else: # pragma: no cover - non-Triton hosts use the PyTorch fallback
+ _phase_encode_kernel = None
+
+
+# Largest n for which the ZZ path materializes the (2^n, n_pairs) pair matrix;
+# above it, ``_iqp_phase`` falls back to a per-pair loop. At n=20 the state
+# vector is 8 MiB in complex64 (16 MiB in complex128), while the float32 pair
+# matrix is 2^20 rows · 190 pairs · 4 B ≈ 760 MiB, so this is the cutoff
+# before pair_matrix dominates the AMD HBM budget.
+_IQP_PAIR_MATRIX_MAX_N = 20
+
+
@dataclass
class TritonAmdEngine:
- """AMD backend implementing amplitude/angle/basis encoders."""
+ """AMD backend implementing amplitude/angle/basis/iqp/iqp-z/phase encoders."""
device_id: int = 0
precision: str = "float32"
+ # Per-engine cache of (n_qubits → bits table) keyed by (n, real_dtype).
+ # Avoids regenerating the (state_len, n_qubits) bit pattern on every call;
+ # the table is reused across batches for any encoder that needs it.
+ _bits_cache: dict = field(default_factory=dict, repr=False, compare=False)
+ # Cache of (n → upper-triangular pair index) for IQP-ZZ.
+ _pair_cache: dict = field(default_factory=dict, repr=False, compare=False)
+
def __post_init__(self) -> None:
p = self.precision.lower()
if p in ("float32", "f32", "float"):
@@ -105,6 +179,18 @@ class TritonAmdEngine:
def _to_2d(self, data: Any, *, dtype: Any) -> Any:
torch_mod = self._require_torch()
+ # Fast path: caller already supplies a 2-D, contiguous, on-device,
+ # correctly-typed torch tensor (the common case for benchmarks and
+ # downstream pipelines). Skip ``as_tensor`` + ``contiguous`` work.
+ if (
+ isinstance(data, torch_mod.Tensor)
+ and data.ndim == 2
+ and data.dtype is dtype
+ and data.is_contiguous()
+ and data.device.type == "cuda"
+ and data.device.index == self.device_id
+ ):
+ return data
x = torch_mod.as_tensor(data, device=self._device(), dtype=dtype)
if x.ndim == 1:
x = x.unsqueeze(0)
@@ -112,6 +198,37 @@ class TritonAmdEngine:
raise ValueError(f"Expected 1D or 2D input, got {x.ndim}D.")
return x.contiguous()
+ def _bits_table(self, num_qubits: int, real_dtype: Any) -> Any:
+ """Cached ``bits[b, k] = (b >> k) & 1`` table cast to ``real_dtype``.
+
+ Returned shape is ``(2^num_qubits, num_qubits)``. The same table is
+ reused by ``encode_angle``/``encode_iqp``/``encode_phase`` across
+ successive batches at the same ``num_qubits``.
+ """
+ torch_mod = self._require_torch()
+ key = (num_qubits, real_dtype)
+ cached = self._bits_cache.get(key)
+ if cached is not None:
+ return cached
+ device = torch_mod.device(self._device())
+ state_len = 1 << num_qubits
+ b_idx = torch_mod.arange(state_len, device=device, dtype=torch_mod.int64)
+ k_idx = torch_mod.arange(num_qubits, device=device, dtype=torch_mod.int64)
+ bits = ((b_idx.unsqueeze(1) >> k_idx) & 1).to(real_dtype).contiguous()
+ self._bits_cache[key] = bits
+ return bits
+
+ def _pair_indices(self, num_qubits: int) -> Any:
+ """Cached ``(n*(n-1)/2, 2)`` table of upper-triangular qubit pairs."""
+ torch_mod = self._require_torch()
+ cached = self._pair_cache.get(num_qubits)
+ if cached is not None:
+ return cached
+ device = torch_mod.device(self._device())
+ pairs = torch_mod.combinations(torch_mod.arange(num_qubits, device=device), r=2)
+ self._pair_cache[num_qubits] = pairs
+ return pairs
+
def encode_amplitude(self, data: Any, num_qubits: int) -> Any:
torch_mod = self._require_torch()
x = self._to_2d(data, dtype=self._real_dtype())
@@ -125,13 +242,12 @@ class TritonAmdEngine:
norms = torch_mod.linalg.vector_norm(x, dim=1, keepdim=True).clamp_min(1e-12)
amp = x / norms
if sample_size < state_len:
- pad = torch_mod.zeros(
- (batch, state_len - sample_size), device=amp.device, dtype=amp.dtype
- )
- amp = torch_mod.cat([amp, pad], dim=1)
- return torch_mod.complex(amp, torch_mod.zeros_like(amp)).to(
- self._complex_dtype()
- )
+ # F.pad is a single fused op vs a separate zeros + cat.
+ amp = torch_mod.nn.functional.pad(amp, (0, state_len - sample_size))
+ # ``.to(complex_dtype)`` from a real tensor is one kernel that writes
+ # (real, 0) interleaved — strictly better than building a separate
+ # zeros tensor and combining via ``torch.complex(real, zeros)``.
+ return amp.to(self._complex_dtype())
def encode_angle(self, data: Any, num_qubits: int) -> Any:
torch_mod = self._require_torch()
@@ -143,21 +259,18 @@ class TritonAmdEngine:
f"Angle encoding expects sample size {num_qubits} (=num_qubits), got {width}."
)
- state_len = 1 << num_qubits
- idx = torch_mod.arange(state_len, device=angles.device).reshape(1, state_len)
- amp = torch_mod.ones((batch, state_len), device=angles.device, dtype=real_dtype)
- for bit in range(num_qubits):
- col = angles[:, bit].unsqueeze(1)
- factor = torch_mod.where(
- ((idx >> bit) & 1) == 1,
- torch_mod.sin(col),
- torch_mod.cos(col),
- )
- amp = amp * factor
+ bits = self._bits_table(num_qubits, real_dtype) # (S, n) cached
- return torch_mod.complex(amp, torch_mod.zeros_like(amp)).to(
- self._complex_dtype()
- )
+ # amp[batch, b] = prod_k (sin(θ_k) if bit_k else cos(θ_k))
+ # Closed-form vectorization: broadcast (B, 1, n) sin/cos against
+ # (1, S, n) bit pattern, gather via where, reduce-product over k.
+ # One allocation for the (B, S, n) workspace; the previous Python-level
+ # n-step loop allocated a fresh (B, S) tensor per iteration.
+ sin = torch_mod.sin(angles).unsqueeze(1)
+ cos = torch_mod.cos(angles).unsqueeze(1)
+ factor = torch_mod.where(bits.unsqueeze(0) > 0.5, sin, cos)
+ amp = factor.prod(dim=2)
+ return amp.to(self._complex_dtype())
def encode_basis(self, data: Any, num_qubits: int) -> Any:
torch_mod = self._require_torch()
@@ -179,22 +292,181 @@ class TritonAmdEngine:
)
batch = int(idx.numel())
+ complex_dtype = self._complex_dtype()
out = torch_mod.zeros(
(batch, state_len),
device=idx.device,
- dtype=self._complex_dtype(),
+ dtype=complex_dtype,
)
out.scatter_(
1,
idx.reshape(batch, 1),
- torch_mod.ones(
- (batch, 1),
- device=idx.device,
- dtype=self._complex_dtype(),
- ),
+ torch_mod.ones((batch, 1), device=idx.device, dtype=complex_dtype),
)
return out
+ def _iqp_phase(
+ self,
+ params: Any,
+ num_qubits: int,
+ bits: Any,
+ *,
+ enable_zz: bool,
+ ) -> Any:
+ """Compute θ(x) = Σ x_i·data_i (+ Σ_{i<j} x_i x_j data_ij if ZZ).
+
+ Returns shape ``(batch, 2**num_qubits)`` in the real dtype.
+ """
+ torch_mod = self._require_torch()
+ n = num_qubits
+ z_params = params[:, :n]
+ # phase = z_params @ bits.T : (B, S)
+ phase = torch_mod.matmul(z_params, bits.T)
+ if enable_zz and n >= 2:
+ if n > _IQP_PAIR_MATRIX_MAX_N:
+ # Pair matrix is (S, n_pairs) — at n=20 that's already ~760 MiB
+ # in float32. Past this size, fall back to a per-pair loop.
+ # Slower but bounded memory; the workload itself is also
+ # impractical at this point (state vector alone is multi-GB).
+ pair_idx = n
+ zz_params = params
+ for i in range(n - 1):
+ bi = bits[:, i]
+ for j in range(i + 1, n):
+ bj = bits[:, j]
+ phase = phase + zz_params[:, pair_idx : pair_idx + 1] * (
+ bi * bj
+ ).unsqueeze(0)
+ pair_idx += 1
+ else:
+ zz_params = params[:, n:]
+ pairs = self._pair_indices(n)
+ pair_matrix = bits[:, pairs[:, 0]] * bits[:, pairs[:, 1]]
+ phase = phase + torch_mod.matmul(zz_params, pair_matrix.T)
+ return phase
+
+ def encode_iqp(
+ self,
+ data: Any,
+ num_qubits: int,
+ *,
+ enable_zz: bool = True,
+ ) -> Any:
+ torch_mod = self._require_torch()
+ real_dtype = self._real_dtype()
+ params = self._to_2d(data, dtype=real_dtype)
+ batch, width = params.shape
+
+ n = num_qubits
+ expected = n + n * (n - 1) // 2 if enable_zz else n
+ if width != expected:
+ variant = "ZZ" if enable_zz else "Z-only"
+ raise ValueError(
+ f"IQP encoding ({variant}) expects {expected} parameters for {n} qubits, got {width}."
+ )
+
+ state_len = 1 << n
+ bits = self._bits_table(n, real_dtype)
+ phase = self._iqp_phase(params, n, bits, enable_zz=enable_zz)
+
+ # f[x] = exp(i·θ(x)). ``torch.complex(cos, sin)`` allocates a single
+ # contiguous complex tensor and is faster than writing into strided
+ # ``.real``/``.imag`` views of a separately-allocated complex buffer.
+ f = torch_mod.complex(torch_mod.cos(phase), torch_mod.sin(phase)).to(
+ self._complex_dtype()
+ )
+
+ # In-place n-stage Walsh-Hadamard butterfly. View ``f`` as
+ # (B, K, 2, stride) per stage and do (a, b) ← (a + b, a - b) using a
+ # single ``state_len/2``-sized scratch buffer instead of allocating
+ # two (lo+hi, lo-hi) buffers and concatenating them every stage.
+ if n > 0:
+ scratch = torch_mod.empty(
+ (batch, state_len // 2), device=f.device, dtype=f.dtype
+ )
+ for s in range(n):
+ stride = 1 << s
+ view = f.view(batch, state_len // (stride * 2), 2, stride)
+ a = view.select(2, 0)
+ b = view.select(2, 1)
+ scratch_view = scratch.view(batch, state_len // (stride * 2), stride)
+ torch_mod.sub(a, b, out=scratch_view) # scratch ← a − b
+ a.add_(b) # a ← a + b (in-place)
+ b.copy_(scratch_view) # b ← (a − b) from scratch
+ f = f.view(batch, state_len)
+
+ f.mul_(1.0 / float(state_len))
+ return f
+
+ def _can_use_triton_phase_kernel(self, num_qubits: int) -> bool:
+ return (
+ _phase_encode_kernel is not None
+ and self.precision == "float32"
+ and 1 <= num_qubits <= 32
+ )
+
+ def _encode_phase_triton(self, phases: Any, num_qubits: int) -> Any:
+ """Triton-fused phase encoder for float32 / n ≤ 32.
+
+ One HIP kernel launch, one program per (sample, output-tile) pair; fuses the
+ bit-table materialization + θ(b) accumulate + cos/sin + 1/√2^n scale
+ + complex-pack into a single pass that writes the output buffer
+ interleaved (re, im, re, im, …) — the native complex64 layout.
+ """
+ torch_mod = self._require_torch()
+ # ``_can_use_triton_phase_kernel`` already guards on Triton being
+ # available; this assertion narrows the type for the type checker.
+ assert _phase_encode_kernel is not None
+ batch = phases.shape[0]
+ state_len = 1 << num_qubits
+
+ # Allocate the complex output once; pass its real-view as a flat
+ # (B, 2·S) float32 buffer to the kernel for direct interleaved writes.
+ out = torch_mod.empty(
+ (batch, state_len),
+ device=phases.device,
+ dtype=torch_mod.complex64,
+ )
+ out_real_view = torch_mod.view_as_real(out).view(batch, state_len * 2)
+
+ norm = math.pow(math.sqrt(0.5), num_qubits)
+ BLOCK = 256
+ grid = (batch, (state_len + BLOCK - 1) // BLOCK)
+ _phase_encode_kernel[grid](
+ phases,
+ out_real_view,
+ num_qubits,
+ state_len,
+ norm,
+ BLOCK=BLOCK,
+ )
+ return out
+
+ def encode_phase(self, data: Any, num_qubits: int) -> Any:
+ torch_mod = self._require_torch()
+ real_dtype = self._real_dtype()
+ phases = self._to_2d(data, dtype=real_dtype)
+ batch, width = phases.shape
+ if width != num_qubits:
+ raise ValueError(
+ f"Phase encoding expects sample size {num_qubits} (=num_qubits), got {width}."
+ )
+
+ if self._can_use_triton_phase_kernel(num_qubits):
+ return self._encode_phase_triton(phases, num_qubits)
+
+ # Fallback: vectorized PyTorch path (float64 or n > 32).
+ bits = self._bits_table(num_qubits, real_dtype)
+ phi = torch_mod.matmul(phases, bits.T)
+ norm = math.pow(math.sqrt(0.5), num_qubits)
+ # ``torch.complex(re, im)`` writes a contiguous interleaved buffer in
+ # one allocation — faster than ``empty(complex)`` followed by strided
+ # writes into ``.real``/``.imag``.
+ return torch_mod.complex(
+ torch_mod.cos(phi).mul_(norm),
+ torch_mod.sin(phi).mul_(norm),
+ ).to(self._complex_dtype())
+
def encode(
self,
data: Any,
@@ -210,6 +482,13 @@ class TritonAmdEngine:
return self.encode_angle(data, num_qubits)
if method == "basis":
return self.encode_basis(data, num_qubits)
+ if method == "iqp":
+ return self.encode_iqp(data, num_qubits, enable_zz=True)
+ if method == "iqp-z":
+ return self.encode_iqp(data, num_qubits, enable_zz=False)
+ if method == "phase":
+ return self.encode_phase(data, num_qubits)
raise ValueError(
- f"Unsupported encoding '{encoding_method}'. triton_amd supports amplitude, angle, basis."
+ f"Unsupported encoding '{encoding_method}'. "
+ "triton_amd supports amplitude, angle, basis, iqp, iqp-z, phase."
)
diff --git a/qdp/qdp-python/tests/test_triton_amd_backend.py b/qdp/qdp-python/tests/test_triton_amd_backend.py
index 1263f65e9..ff3341568 100644
--- a/qdp/qdp-python/tests/test_triton_amd_backend.py
+++ b/qdp/qdp-python/tests/test_triton_amd_backend.py
@@ -14,9 +14,12 @@
# See the License for the specific language governing permissions and
# limitations under the License.
+import math
+
import pytest
import torch
from qumat_qdp import QdpEngine, is_triton_amd_available
+from qumat_qdp.torch_ref import iqp_encode as _torch_ref_iqp
from qumat_qdp.triton_amd import TritonAmdEngine
@@ -50,6 +53,21 @@ def _torch_angle_ref(angles: torch.Tensor, num_qubits: int) -> torch.Tensor:
return torch.complex(amp, torch.zeros_like(amp))
+def _torch_phase_ref(phases: torch.Tensor, num_qubits: int) -> torch.Tensor:
+ real_dtype = phases.dtype
+ batch = phases.shape[0]
+ state_len = 1 << num_qubits
+ idx = torch.arange(state_len, device=phases.device, dtype=torch.int64)
+ bits = (
+ (idx.unsqueeze(1) >> torch.arange(num_qubits, device=phases.device)) & 1
+ ).to(real_dtype)
+ phi = phases @ bits.T
+ norm = math.pow(math.sqrt(0.5), num_qubits)
+ out = torch.complex(torch.cos(phi) * norm, torch.sin(phi) * norm)
+ assert out.shape == (batch, state_len)
+ return out
+
+
def _torch_basis_ref(idx: torch.Tensor, num_qubits: int) -> torch.Tensor:
idx = idx.to(torch.int64)
batch = idx.numel()
@@ -187,3 +205,160 @@ def test_unified_router_contract_returns_torch_tensor() -> None:
assert isinstance(qt, torch.Tensor)
assert qt.shape == (2, 4)
assert qt.dtype == torch.complex64
+
+
+@pytest.mark.skipif(
+ not is_triton_amd_available(), reason="Triton AMD backend unavailable"
+)
+@pytest.mark.rocm
+def test_triton_amd_iqp_full_parity_with_torch_ref() -> None:
+ n = 4
+ engine = TritonAmdEngine(device_id=0, precision="float32")
+ data = torch.randn(3, n + n * (n - 1) // 2, device="cuda", dtype=torch.float32)
+ got = _as_torch(engine.encode(data, n, "iqp"))
+ ref = _torch_ref_iqp(data, n, enable_zz=True)
+ assert got.shape == ref.shape
+ assert got.dtype == torch.complex64
+ assert torch.allclose(got, ref, atol=2e-5, rtol=2e-5)
+
+
+@pytest.mark.skipif(
+ not is_triton_amd_available(), reason="Triton AMD backend unavailable"
+)
+@pytest.mark.rocm
+def test_triton_amd_iqp_z_only_parity_with_torch_ref() -> None:
+ n = 5
+ engine = TritonAmdEngine(device_id=0, precision="float32")
+ data = torch.randn(2, n, device="cuda", dtype=torch.float32)
+ got = _as_torch(engine.encode(data, n, "iqp-z"))
+ ref = _torch_ref_iqp(data, n, enable_zz=False)
+ assert got.shape == ref.shape
+ assert torch.allclose(got, ref, atol=2e-5, rtol=2e-5)
+
+
+@pytest.mark.skipif(
+ not is_triton_amd_available(), reason="Triton AMD backend unavailable"
+)
+@pytest.mark.rocm
+def test_triton_amd_iqp_param_count_validation() -> None:
+ engine = TritonAmdEngine(device_id=0, precision="float32")
+ # ZZ variant for n=4 expects 4 + 6 = 10 params; pass 9.
+ bad = torch.randn(2, 9, device="cuda", dtype=torch.float32)
+ with pytest.raises(ValueError, match="expects 10 parameters"):
+ engine.encode(bad, 4, "iqp")
+ # Z-only variant for n=4 expects 4 params; pass 5.
+ bad_z = torch.randn(2, 5, device="cuda", dtype=torch.float32)
+ with pytest.raises(ValueError, match="expects 4 parameters"):
+ engine.encode(bad_z, 4, "iqp-z")
+
+
+@pytest.mark.skipif(
+ not is_triton_amd_available(), reason="Triton AMD backend unavailable"
+)
+@pytest.mark.rocm
+def test_triton_amd_iqp_normalization_unit_norm() -> None:
+ """IQP output is a normalized state vector: Σ|amp|² ≈ 1."""
+ engine = TritonAmdEngine(device_id=0, precision="float32")
+ n = 6
+ data = torch.randn(4, n + n * (n - 1) // 2, device="cuda", dtype=torch.float32)
+ got = _as_torch(engine.encode(data, n, "iqp"))
+ norms = (got.abs() ** 2).sum(dim=1)
+ assert torch.allclose(norms, torch.ones_like(norms), atol=1e-4, rtol=1e-4)
+
+
+@pytest.mark.skipif(
+ not is_triton_amd_available(), reason="Triton AMD backend unavailable"
+)
+@pytest.mark.rocm
+def test_triton_amd_phase_parity() -> None:
+ engine = TritonAmdEngine(device_id=0, precision="float32")
+ phases = torch.randn(3, 5, device="cuda", dtype=torch.float32)
+ got = _as_torch(engine.encode(phases, 5, "phase"))
+ ref = _torch_phase_ref(phases, 5)
+ assert got.shape == ref.shape
+ assert got.dtype == torch.complex64
+ assert torch.allclose(got, ref, atol=1e-5, rtol=1e-5)
+
+
+@pytest.mark.skipif(
+ not is_triton_amd_available(), reason="Triton AMD backend unavailable"
+)
+@pytest.mark.rocm
+def test_triton_amd_phase_normalization_unit_norm() -> None:
+ """Phase output is a uniform-magnitude product state: Σ|amp|² ≈ 1."""
+ engine = TritonAmdEngine(device_id=0, precision="float32")
+ n = 6
+ phases = torch.randn(4, n, device="cuda", dtype=torch.float32)
+ got = _as_torch(engine.encode(phases, n, "phase"))
+ norms = (got.abs() ** 2).sum(dim=1)
+ assert torch.allclose(norms, torch.ones_like(norms), atol=1e-4, rtol=1e-4)
+
+
+@pytest.mark.skipif(
+ not is_triton_amd_available(), reason="Triton AMD backend unavailable"
+)
+@pytest.mark.rocm
+def test_triton_amd_phase_param_count_validation() -> None:
+ engine = TritonAmdEngine(device_id=0, precision="float32")
+ bad = torch.randn(2, 3, device="cuda", dtype=torch.float32)
+ with pytest.raises(ValueError, match="sample size 4"):
+ engine.encode(bad, 4, "phase")
+
+
+@pytest.mark.skipif(
+ not is_triton_amd_available(), reason="Triton AMD backend unavailable"
+)
+@pytest.mark.rocm
+def test_triton_amd_phase_float64_precision_contract() -> None:
+ engine = TritonAmdEngine(device_id=0, precision="float64")
+ phases = torch.randn(2, 4, device="cuda", dtype=torch.float64)
+ got = _as_torch(engine.encode(phases, 4, "phase"))
+ ref = _torch_phase_ref(phases, 4).to(torch.complex128)
+ assert got.dtype == torch.complex128
+ assert torch.allclose(got, ref, atol=1e-12, rtol=1e-12)
+
+
+@pytest.mark.skipif(
+ not is_triton_amd_available(), reason="Triton AMD backend unavailable"
+)
+@pytest.mark.rocm
+def test_triton_amd_iqp_float64_precision_contract() -> None:
+ """Float64 IQP matches torch_ref bit-close (covers the dtype contract)."""
+ engine = TritonAmdEngine(device_id=0, precision="float64")
+ n = 4
+ data = torch.randn(3, n + n * (n - 1) // 2, device="cuda", dtype=torch.float64)
+ got = _as_torch(engine.encode(data, n, "iqp"))
+ ref = _torch_ref_iqp(data, n, enable_zz=True).to(torch.complex128)
+ assert got.dtype == torch.complex128
+ assert torch.allclose(got, ref, atol=1e-12, rtol=1e-12)
+
+
+@pytest.mark.skipif(
+ not is_triton_amd_available(), reason="Triton AMD backend unavailable"
+)
+@pytest.mark.rocm
+def test_triton_amd_unsupported_method_message_lists_all() -> None:
+ engine = TritonAmdEngine(device_id=0, precision="float32")
+ with pytest.raises(ValueError) as excinfo:
+ engine.encode(torch.zeros(1, 4, device="cuda"), 2, "no-such-method")
+ msg = str(excinfo.value)
+ for name in ("amplitude", "angle", "basis", "iqp", "iqp-z", "phase"):
+ assert name in msg
+
+
+@pytest.mark.skipif(
+ not is_triton_amd_available(), reason="Triton AMD backend unavailable"
+)
+@pytest.mark.rocm
+def test_unified_router_iqp_and_phase_routes() -> None:
+ """The public QdpEngine(backend='amd') router accepts iqp/iqp-z/phase too."""
+ router = QdpEngine(backend="amd", device_id=0, precision="float32")
+ n = 3
+ data_iqp = torch.randn(2, n + n * (n - 1) // 2, device="cuda", dtype=torch.float32)
+ qt = router.encode(data_iqp, n, "iqp")
+ assert isinstance(qt, torch.Tensor)
+ assert qt.shape == (2, 1 << n)
+ qt_z = router.encode(torch.randn(2, n, device="cuda"), n, "iqp-z")
+ assert qt_z.shape == (2, 1 << n)
+ qt_p = router.encode(torch.randn(2, n, device="cuda"), n, "phase")
+ assert qt_p.shape == (2, 1 << n)