This is an automated email from the ASF dual-hosted git repository.
hcr pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/mahout.git
The following commit(s) were added to refs/heads/main by this push:
new 4eaeb0809 Fix/dynamic cuda encoding method list (#1100)
4eaeb0809 is described below
commit 4eaeb0809d14d7a069854b16ed2e40b64c07df9b
Author: Suyash Parmar <[email protected]>
AuthorDate: Wed Mar 4 12:51:38 2026 +0530
Fix/dynamic cuda encoding method list (#1100)
* Enable CUDA tensor validation for iqp and iqp-z encodings
* Make CUDA unsupported-method error list dynamic
* Refine CUDA method-list formatter per review feedback
* Move CUDA encoding constants into shared module
* Use centralized method list as CUDA support gate
---------
Co-authored-by: Suyash Parmar <[email protected]>
Co-authored-by: Ryan Huang <[email protected]>
---
qdp/qdp-python/src/constants.rs | 32 ++++++++++++++++++++++++++++++++
qdp/qdp-python/src/lib.rs | 1 +
qdp/qdp-python/src/pytorch.rs | 31 ++++++++++++++++---------------
3 files changed, 49 insertions(+), 15 deletions(-)
diff --git a/qdp/qdp-python/src/constants.rs b/qdp/qdp-python/src/constants.rs
new file mode 100644
index 000000000..433c6f03c
--- /dev/null
+++ b/qdp/qdp-python/src/constants.rs
@@ -0,0 +1,32 @@
+//
+// Licensed to the Apache Software Foundation (ASF) under one or more
+// contributor license agreements. See the NOTICE file distributed with
+// this work for additional information regarding copyright ownership.
+// The ASF licenses this file to You under the Apache License, Version 2.0
+// (the "License"); you may not use this file except in compliance with
+// the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
/// Encoding methods that accept CUDA tensors directly.
///
/// This list is the single gate for CUDA tensor validation: a method not listed
/// here is rejected, and the same list is rendered into the resulting error text.
/// Entries must be lowercase, because callers compare them against an
/// ASCII-lowercased method name.
pub const CUDA_ENCODING_METHODS: &[&str] = &["amplitude", "angle", "basis", "iqp", "iqp-z"];

/// Renders [`CUDA_ENCODING_METHODS`] as a human-readable list for error messages.
///
/// Each entry is single-quoted, e.g. `'amplitude', 'angle', 'basis', 'iqp', or 'iqp-z'`.
/// Returns an empty string if the list is empty.
pub fn format_supported_cuda_encoding_methods() -> String {
    format_quoted_list(CUDA_ENCODING_METHODS)
}

/// Joins `items` as single-quoted entries in English list form.
///
/// Output by list length:
/// - 0: `""`
/// - 1: `'a'`
/// - 2: `'a' or 'b'` (no comma, since a serial comma needs at least three items)
/// - 3 or more: `'a', 'b', or 'c'`
fn format_quoted_list(items: &[&str]) -> String {
    match items {
        [] => String::new(),
        [only] => format!("'{}'", only),
        [first, second] => format!("'{}' or '{}'", first, second),
        [rest @ .., last] => {
            let rest = rest
                .iter()
                .map(|item| format!("'{}'", item))
                .collect::<Vec<_>>()
                .join(", ");
            format!("{}, or '{}'", rest, last)
        }
    }
}
diff --git a/qdp/qdp-python/src/lib.rs b/qdp/qdp-python/src/lib.rs
index 6365992a6..014f992ba 100644
--- a/qdp/qdp-python/src/lib.rs
+++ b/qdp/qdp-python/src/lib.rs
@@ -14,6 +14,7 @@
// See the License for the specific language governing permissions and
// limitations under the License.
+mod constants;
mod dlpack;
mod engine;
mod loader;
diff --git a/qdp/qdp-python/src/pytorch.rs b/qdp/qdp-python/src/pytorch.rs
index 538871292..17185fd54 100644
--- a/qdp/qdp-python/src/pytorch.rs
+++ b/qdp/qdp-python/src/pytorch.rs
@@ -18,6 +18,8 @@ use pyo3::exceptions::PyRuntimeError;
use pyo3::prelude::*;
use std::ffi::c_void;
+use crate::constants::{CUDA_ENCODING_METHODS,
format_supported_cuda_encoding_methods};
+
/// Helper to detect PyTorch tensor
pub fn is_pytorch_tensor(obj: &Bound<'_, PyAny>) -> PyResult<bool> {
let type_obj = obj.get_type();
@@ -150,6 +152,15 @@ pub fn validate_cuda_tensor_for_encoding(
) -> PyResult<()> {
let method = encoding_method.to_ascii_lowercase();
+ if !CUDA_ENCODING_METHODS.contains(&method.as_str()) {
+ return Err(PyRuntimeError::new_err(format!(
+ "CUDA tensor encoding currently only supports {} methods, got
'{}'. \
+ Use tensor.cpu() to convert to CPU tensor for other encoding
methods.",
+ format_supported_cuda_encoding_methods(),
+ encoding_method
+ )));
+ }
+
// Check encoding method support and dtype (ASCII lowercase for
case-insensitive match).
let dtype = tensor.getattr("dtype")?;
let dtype_str: String = dtype.str()?.extract()?;
@@ -164,12 +175,12 @@ pub fn validate_cuda_tensor_for_encoding(
)));
}
}
- "angle" => {
+ "angle" | "iqp" | "iqp-z" => {
if !dtype_str_lower.contains("float64") {
return Err(PyRuntimeError::new_err(format!(
- "CUDA tensor must have dtype float64 for angle encoding,
got {}. \
+ "CUDA tensor must have dtype float64 for {} encoding, got
{}. \
Use tensor.to(torch.float64)",
- dtype_str
+ method, dtype_str
)));
}
}
@@ -182,20 +193,10 @@ pub fn validate_cuda_tensor_for_encoding(
)));
}
}
- "iqp" | "iqp-z" => {
- if !dtype_str_lower.contains("float64") {
- return Err(PyRuntimeError::new_err(format!(
- "CUDA tensor must have dtype float64 for {} encoding, got
{}. \
- Use tensor.to(torch.float64)",
- method, dtype_str
- )));
- }
- }
_ => {
return Err(PyRuntimeError::new_err(format!(
- "CUDA tensor encoding currently only supports 'amplitude',
'angle', 'basis', 'iqp', or 'iqp-z' methods, got '{}'. \
- Use tensor.cpu() to convert to CPU tensor for other encoding
methods.",
- encoding_method
+ "Internal error: missing CUDA validation branch for supported
method '{}'",
+ method
)));
}
}