This is an automated email from the ASF dual-hosted git repository.
guanmingchiu pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/mahout.git
The following commit(s) were added to refs/heads/main by this push:
new 47d899b8d [QDP] Enhance numerical safety checks in preprocessor (#805)
47d899b8d is described below
commit 47d899b8da3126925b53ab615922d32b542e3f30
Author: Vic Wen <[email protected]>
AuthorDate: Thu Jan 15 10:39:11 2026 +0800
[QDP] Enhance numerical safety checks in preprocessor (#805)
* feat: enhance numerical safety checks in preprocessor
1. Added validation for NaN and Infinity values in input data and L2 norm
calculations.
2. Updated related tests to cover these scenarios.
* refactor: improve L2 norm calculation and numerical safety checks
1. Removed redundant numerical safety check before L2 norm calculation,
leveraging IEEE 754 propagation for NaN/Infinity.
2. Enhanced the check_numerical_safety function to use parallel iteration
for efficiency.
* fix: unify NaN and Infinity checks in numerical safety validation
Updated the check_numerical_safety function to combine checks for NaN and
Infinity values into a single validation, improving clarity and reducing
redundancy in error handling.
* refactor: optimize numerical safety checks for debug builds
Modified the check_numerical_safety function to only perform NaN and
Infinity checks in debug builds, reducing overhead in release mode while
maintaining data integrity validation.
* refactor: remove redundant numerical-safety checks
* test: add tests for L2 norm calculations with invalid values
* chore: add approx crate as a dependency for qdp-core
* test: update assertions to use approx crate for relative equality checks;
restore deleted comments
---
qdp/Cargo.lock | 10 +++++++
qdp/qdp-core/Cargo.toml | 3 +++
qdp/qdp-core/src/preprocessing.rs | 25 +++++++++++++++++
qdp/qdp-core/tests/preprocessing.rs | 54 ++++++++++++++++++++++++++++++++++---
4 files changed, 88 insertions(+), 4 deletions(-)
diff --git a/qdp/Cargo.lock b/qdp/Cargo.lock
index d316707e3..e175e1f60 100644
--- a/qdp/Cargo.lock
+++ b/qdp/Cargo.lock
@@ -61,6 +61,15 @@ version = "1.0.100"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "a23eb6b1614318a8071c9b2521f36b424b2c83db5eb3a0fead4a6c0809af6e61"
+[[package]]
+name = "approx"
+version = "0.5.1"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "cab112f0a86d568ea0e627cc1d6be74a1e9cd55214684db5561995f6dad897c6"
+dependencies = [
+ "num-traits",
+]
+
[[package]]
name = "arbitrary"
version = "1.4.2"
@@ -1374,6 +1383,7 @@ dependencies = [
name = "qdp-core"
version = "0.1.0"
dependencies = [
+ "approx",
"arrow",
"bytes",
"cudarc",
diff --git a/qdp/qdp-core/Cargo.toml b/qdp/qdp-core/Cargo.toml
index c4c27533c..1c92f0332 100644
--- a/qdp/qdp-core/Cargo.toml
+++ b/qdp/qdp-core/Cargo.toml
@@ -26,3 +26,6 @@ name = "qdp_core"
[features]
default = []
observability = ["nvtx"]
+
+[dev-dependencies]
+approx = "0.5.1"
diff --git a/qdp/qdp-core/src/preprocessing.rs b/qdp/qdp-core/src/preprocessing.rs
index c790febf2..0369f6f62 100644
--- a/qdp/qdp-core/src/preprocessing.rs
+++ b/qdp/qdp-core/src/preprocessing.rs
@@ -79,6 +79,18 @@ impl Preprocessor {
));
}
+ if norm.is_nan() {
+ return Err(MahoutError::InvalidInput(
+ "Input data contains NaN (Not a Number) values".to_string(),
+ ));
+ }
+
+ if norm.is_infinite() {
+ return Err(MahoutError::InvalidInput(
+ "Input data contains Infinity values".to_string(),
+ ));
+ }
+
Ok(norm)
}
@@ -143,6 +155,19 @@ impl Preprocessor {
i
)));
}
+ // Check result for NaN and Infinity
+ if norm.is_nan() {
+ return Err(MahoutError::InvalidInput(format!(
+ "Sample {} produced NaN norm",
+ i
+ )));
+ }
+ if norm.is_infinite() {
+ return Err(MahoutError::InvalidInput(format!(
+ "Sample {} produced Infinity norm",
+ i
+ )));
+ }
Ok(norm)
})
.collect()
diff --git a/qdp/qdp-core/tests/preprocessing.rs b/qdp/qdp-core/tests/preprocessing.rs
index bd1958308..4b3b16c3d 100644
--- a/qdp/qdp-core/tests/preprocessing.rs
+++ b/qdp/qdp-core/tests/preprocessing.rs
@@ -14,6 +14,7 @@
// See the License for the specific language governing permissions and
// limitations under the License.
+use approx::assert_relative_eq;
use qdp_core::MahoutError;
use qdp_core::preprocessing::Preprocessor;
@@ -76,11 +77,11 @@ fn test_validate_input_max_qubits_boundary() {
fn test_calculate_l2_norm_success() {
let data = vec![3.0, 4.0];
let norm = Preprocessor::calculate_l2_norm(&data).unwrap();
- assert!((norm - 5.0).abs() < 1e-10);
+ assert_relative_eq!(norm, 5.0); // |(3, 4)| = sqrt(9 + 16) = 5
let data = vec![1.0, 1.0];
let norm = Preprocessor::calculate_l2_norm(&data).unwrap();
- assert!((norm - 2.0_f64.sqrt()).abs() < 1e-10);
+ assert_relative_eq!(norm, 2.0_f64.sqrt()); // |(1, 1)| = sqrt(1 + 1) = sqrt(2)
}
#[test]
@@ -94,7 +95,7 @@ fn test_calculate_l2_norm_zero() {
fn test_calculate_l2_norm_mixed_signs() {
let data = vec![-3.0, 4.0];
let norm = Preprocessor::calculate_l2_norm(&data).unwrap();
- assert!((norm - 5.0).abs() < 1e-10);
+ assert_relative_eq!(norm, 5.0); // (-3)^2 + 4^2 = 25: the sign is squared away
}
#[test]
@@ -103,5 +104,50 @@ fn test_calculate_l2_norm_matches_sequential_sum() {
let norm_parallel = Preprocessor::calculate_l2_norm(&data).unwrap();
let norm_sequential = data.iter().map(|x| x * x).sum::<f64>().sqrt();
- assert!((norm_parallel - norm_sequential).abs() < 1e-10);
+ assert_relative_eq!(norm_parallel, norm_sequential);
+}
+
+#[test]
+fn test_calculate_l2_norm_invalid_values() {
+    let cases = [
+        ("NaN", f64::NAN, "NaN"),
+        ("+Inf", f64::INFINITY, "Infinity"),
+        ("-Inf", f64::NEG_INFINITY, "Infinity"),
+    ];
+
+    for (label, bad_value, expected_fragment) in cases {
+        let data = vec![1.0, bad_value, 3.0]; // one non-finite entry among finite ones
+        let result = Preprocessor::calculate_l2_norm(&data); // matched by reference so it can be printed on failure
+        assert!(
+            matches!(&result, Err(MahoutError::InvalidInput(msg)) if msg.contains(expected_fragment)),
+            "case {label}: expected InvalidInput mentioning {expected_fragment:?}, got {result:?}"
+        );
+    }
+}
+
+#[test]
+fn test_calculate_batch_l2_norms_invalid_values() {
+    let cases = [
+        ("NaN", f64::NAN, "NaN"),
+        ("+Inf", f64::INFINITY, "Infinity"),
+        ("-Inf", f64::NEG_INFINITY, "Infinity"),
+    ];
+
+    for (label, bad_value, expected_fragment) in cases {
+        let batch_data = vec![1.0, 2.0, bad_value, 4.0]; // one non-finite entry in the batch
+        let result = Preprocessor::calculate_batch_l2_norms(&batch_data, 2, 2); // matched by reference so it can be printed on failure
+        assert!(
+            matches!(&result, Err(MahoutError::InvalidInput(msg)) if msg.contains(expected_fragment)),
+            "case {label}: expected InvalidInput mentioning {expected_fragment:?}, got {result:?}"
+        );
+    }
+}
+
+#[test]
+fn test_calculate_batch_l2_norms_success() {
+    let batch_data = vec![3.0, 4.0, 5.0, 12.0]; // rows (3, 4) and (5, 12): both have integer norms
+    let norms = Preprocessor::calculate_batch_l2_norms(&batch_data, 2, 2).unwrap();
+    assert_eq!(norms.len(), 2);
+    assert_relative_eq!(norms[0], 5.0); // sqrt(3^2 + 4^2) = 5
+    assert_relative_eq!(norms[1], 13.0); // sqrt(5^2 + 12^2) = 13
}