This is an automated email from the ASF dual-hosted git repository.

hcr pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/mahout.git


The following commit(s) were added to refs/heads/main by this push:
     new b68cdfdd9 Enable CUDA tensor validation for iqp and iqp-z encodings (#1093)
b68cdfdd9 is described below

commit b68cdfdd98dee797ad8fff2c6de4fb39368d785a
Author: Suyash Parmar <[email protected]>
AuthorDate: Tue Mar 3 06:13:16 2026 +0530

    Enable CUDA tensor validation for iqp and iqp-z encodings (#1093)
    
    Co-authored-by: Suyash Parmar <[email protected]>
    Co-authored-by: Ryan Huang <[email protected]>
---
 qdp/qdp-python/src/pytorch.rs | 11 ++++++++++-
 testing/qdp/test_bindings.py  | 37 +++++++++++++++++++++++++++++++------
 2 files changed, 41 insertions(+), 7 deletions(-)

diff --git a/qdp/qdp-python/src/pytorch.rs b/qdp/qdp-python/src/pytorch.rs
index cb5c75247..538871292 100644
--- a/qdp/qdp-python/src/pytorch.rs
+++ b/qdp/qdp-python/src/pytorch.rs
@@ -182,9 +182,18 @@ pub fn validate_cuda_tensor_for_encoding(
                 )));
             }
         }
+        "iqp" | "iqp-z" => {
+            if !dtype_str_lower.contains("float64") {
+                return Err(PyRuntimeError::new_err(format!(
+                    "CUDA tensor must have dtype float64 for {} encoding, got {}. \
+                     Use tensor.to(torch.float64)",
+                    method, dtype_str
+                )));
+            }
+        }
         _ => {
             return Err(PyRuntimeError::new_err(format!(
-                "CUDA tensor encoding currently only supports 'amplitude', 'angle', or 'basis' methods, got '{}'. \
+                "CUDA tensor encoding currently only supports 'amplitude', 'angle', 'basis', 'iqp', or 'iqp-z' methods, got '{}'. \
                  Use tensor.cpu() to convert to CPU tensor for other encoding methods.",
                 encoding_method
             )));
diff --git a/testing/qdp/test_bindings.py b/testing/qdp/test_bindings.py
index 0bc971d1c..13916e6d7 100644
--- a/testing/qdp/test_bindings.py
+++ b/testing/qdp/test_bindings.py
@@ -422,9 +422,34 @@ def test_encode_cuda_tensor_preserves_input(data_shape, is_batch):
 
 @requires_qdp
 @pytest.mark.gpu
-@pytest.mark.parametrize("encoding_method", ["iqp"])
-def test_encode_cuda_tensor_unsupported_encoding(encoding_method):
-    """Test error when using CUDA tensor with an encoding not supported on GPU (only amplitude, angle, basis)."""
+@pytest.mark.parametrize(
+    "encoding_method,data",
+    [
+        ("iqp-z", [0.1, -0.2]),
+        ("iqp", [0.1, -0.2, 0.3]),
+    ],
+)
+def test_encode_cuda_tensor_iqp_methods(encoding_method, data):
+    """Test CUDA tensor path supports IQP-family encodings."""
+    pytest.importorskip("torch")
+    from _qdp import QdpEngine
+
+    if not torch.cuda.is_available():
+        pytest.skip("GPU required for QdpEngine")
+
+    engine = QdpEngine(0)
+    cuda_data = torch.tensor(data, dtype=torch.float64, device="cuda:0")
+
+    qtensor = engine.encode(cuda_data, 2, encoding_method)
+    output = torch.from_dlpack(qtensor)
+
+    assert tuple(output.shape) == (1, 4)
+
+
+@requires_qdp
+@pytest.mark.gpu
+def test_encode_cuda_tensor_invalid_encoding_method():
+    """Test error when using CUDA tensor with an unknown encoding method."""
     pytest.importorskip("torch")
     from _qdp import QdpEngine
 
@@ -433,14 +458,14 @@ def test_encode_cuda_tensor_unsupported_encoding(encoding_method):
 
     engine = QdpEngine(0)
 
-    # CUDA path only supports amplitude, angle, basis; iqp/iqp-z should raise unsupported error
+    # Unknown encoding should fail with supported-method guidance.
     data = torch.tensor([1.0, 0.0, 0.0, 0.0], dtype=torch.float64, device="cuda:0")
 
     with pytest.raises(
         RuntimeError,
-        match="only supports .*amplitude.*angle.*basis.*Use tensor.cpu",
+        match="only supports .*amplitude.*angle.*basis.*iqp.*iqp-z.*Use tensor.cpu",
     ):
-        engine.encode(data, 2, encoding_method)
+        engine.encode(data, 2, "unknown-encoding")
 
 
 @requires_qdp

Reply via email to