This is an automated email from the ASF dual-hosted git repository.

guanmingchiu pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/mahout.git


The following commit(s) were added to refs/heads/main by this push:
     new 47d899b8d [QDP] Enhance numerical safety checks in preprocessor (#805)
47d899b8d is described below

commit 47d899b8da3126925b53ab615922d32b542e3f30
Author: Vic Wen <[email protected]>
AuthorDate: Thu Jan 15 10:39:11 2026 +0800

    [QDP] Enhance numerical safety checks in preprocessor (#805)
    
    * feat: enhance numerical safety checks in preprocessor
    
    1. Added validation for NaN and Infinity values in input data and L2 norm calculations.
    2. Updated related tests to cover these scenarios.
    
    * refactor: improve L2 norm calculation and numerical safety checks
    
    1. Removed redundant numerical safety check before L2 norm calculation, leveraging IEEE 754 propagation for NaN/Infinity.
    2. Enhanced the check_numerical_safety function to use parallel iteration for efficiency.
    
    * fix: unify NaN and Infinity checks in numerical safety validation
    
    Updated the check_numerical_safety function to combine checks for NaN and Infinity values into a single validation, improving clarity and reducing redundancy in error handling.
    
    * refactor: optimize numerical safety checks for debug builds
    
    Modified the check_numerical_safety function to only perform NaN and Infinity checks in debug builds, reducing overhead in release mode while maintaining data integrity validation.
    
    * refactor: remove redundant numerical-safety checks
    
    * test: add tests for L2 norm calculations with invalid values
    
    * chore: add approx crate as a dependency for qdp-core
    
    * test: update assertions to use approx crate for relative equality checks; restore deleted comments
---
 qdp/Cargo.lock                      | 10 +++++++
 qdp/qdp-core/Cargo.toml             |  3 +++
 qdp/qdp-core/src/preprocessing.rs   | 25 +++++++++++++++++
 qdp/qdp-core/tests/preprocessing.rs | 54 ++++++++++++++++++++++++++++++++++---
 4 files changed, 88 insertions(+), 4 deletions(-)

diff --git a/qdp/Cargo.lock b/qdp/Cargo.lock
index d316707e3..e175e1f60 100644
--- a/qdp/Cargo.lock
+++ b/qdp/Cargo.lock
@@ -61,6 +61,15 @@ version = "1.0.100"
 source = "registry+https://github.com/rust-lang/crates.io-index"
 checksum = "a23eb6b1614318a8071c9b2521f36b424b2c83db5eb3a0fead4a6c0809af6e61"
 
+[[package]]
+name = "approx"
+version = "0.5.1"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "cab112f0a86d568ea0e627cc1d6be74a1e9cd55214684db5561995f6dad897c6"
+dependencies = [
+ "num-traits",
+]
+
 [[package]]
 name = "arbitrary"
 version = "1.4.2"
@@ -1374,6 +1383,7 @@ dependencies = [
 name = "qdp-core"
 version = "0.1.0"
 dependencies = [
+ "approx",
  "arrow",
  "bytes",
  "cudarc",
diff --git a/qdp/qdp-core/Cargo.toml b/qdp/qdp-core/Cargo.toml
index c4c27533c..1c92f0332 100644
--- a/qdp/qdp-core/Cargo.toml
+++ b/qdp/qdp-core/Cargo.toml
@@ -26,3 +26,6 @@ name = "qdp_core"
 [features]
 default = []
 observability = ["nvtx"]
+
+[dev-dependencies]
+approx = "0.5.1"
diff --git a/qdp/qdp-core/src/preprocessing.rs b/qdp/qdp-core/src/preprocessing.rs
index c790febf2..0369f6f62 100644
--- a/qdp/qdp-core/src/preprocessing.rs
+++ b/qdp/qdp-core/src/preprocessing.rs
@@ -79,6 +79,18 @@ impl Preprocessor {
             ));
         }
 
+        if norm.is_nan() {
+            return Err(MahoutError::InvalidInput(
+                "Input data contains NaN (Not a Number) values".to_string(),
+            ));
+        }
+
+        if norm.is_infinite() {
+            return Err(MahoutError::InvalidInput(
+                "Input data contains Infinity values".to_string(),
+            ));
+        }
+
         Ok(norm)
     }
 
@@ -143,6 +155,19 @@ impl Preprocessor {
                         i
                     )));
                 }
+                // Check result for NaN and Infinity
+                if norm.is_nan() {
+                    return Err(MahoutError::InvalidInput(format!(
+                        "Sample {} produced NaN norm",
+                        i
+                    )));
+                }
+                if norm.is_infinite() {
+                    return Err(MahoutError::InvalidInput(format!(
+                        "Sample {} produced Infinity norm",
+                        i
+                    )));
+                }
                 Ok(norm)
             })
             .collect()
diff --git a/qdp/qdp-core/tests/preprocessing.rs b/qdp/qdp-core/tests/preprocessing.rs
index bd1958308..4b3b16c3d 100644
--- a/qdp/qdp-core/tests/preprocessing.rs
+++ b/qdp/qdp-core/tests/preprocessing.rs
@@ -14,6 +14,7 @@
 // See the License for the specific language governing permissions and
 // limitations under the License.
 
+use approx::assert_relative_eq;
 use qdp_core::MahoutError;
 use qdp_core::preprocessing::Preprocessor;
 
@@ -76,11 +77,11 @@ fn test_validate_input_max_qubits_boundary() {
 fn test_calculate_l2_norm_success() {
     let data = vec![3.0, 4.0];
     let norm = Preprocessor::calculate_l2_norm(&data).unwrap();
-    assert!((norm - 5.0).abs() < 1e-10);
+    assert_relative_eq!(norm, 5.0);
 
     let data = vec![1.0, 1.0];
     let norm = Preprocessor::calculate_l2_norm(&data).unwrap();
-    assert!((norm - 2.0_f64.sqrt()).abs() < 1e-10);
+    assert_relative_eq!(norm, 2.0_f64.sqrt());
 }
 
 #[test]
@@ -94,7 +95,7 @@ fn test_calculate_l2_norm_zero() {
 fn test_calculate_l2_norm_mixed_signs() {
     let data = vec![-3.0, 4.0];
     let norm = Preprocessor::calculate_l2_norm(&data).unwrap();
-    assert!((norm - 5.0).abs() < 1e-10);
+    assert_relative_eq!(norm, 5.0);
 }
 
 #[test]
@@ -103,5 +104,50 @@ fn test_calculate_l2_norm_matches_sequential_sum() {
     let norm_parallel = Preprocessor::calculate_l2_norm(&data).unwrap();
 
     let norm_sequential = data.iter().map(|x| x * x).sum::<f64>().sqrt();
-    assert!((norm_parallel - norm_sequential).abs() < 1e-10);
+    assert_relative_eq!(norm_parallel, norm_sequential);
+}
+
+#[test]
+fn test_calculate_l2_norm_invalid_values() {
+    let cases = [
+        ("NaN", f64::NAN, "NaN"),
+        ("+Inf", f64::INFINITY, "Infinity"),
+        ("-Inf", f64::NEG_INFINITY, "Infinity"),
+    ];
+
+    for (label, bad_value, expected_fragment) in cases {
+        let data = vec![1.0, bad_value, 3.0];
+        let result = Preprocessor::calculate_l2_norm(&data);
+        assert!(
+            matches!(result, Err(MahoutError::InvalidInput(msg)) if msg.contains(expected_fragment)),
+            "case {label} did not produce expected error"
+        );
+    }
+}
+
+#[test]
+fn test_calculate_batch_l2_norms_invalid_values() {
+    let cases = [
+        ("NaN", f64::NAN, "NaN"),
+        ("+Inf", f64::INFINITY, "Infinity"),
+        ("-Inf", f64::NEG_INFINITY, "Infinity"),
+    ];
+
+    for (label, bad_value, expected_fragment) in cases {
+        let batch_data = vec![1.0, 2.0, bad_value, 4.0];
+        let result = Preprocessor::calculate_batch_l2_norms(&batch_data, 2, 2);
+        assert!(
+            matches!(result, Err(MahoutError::InvalidInput(msg)) if msg.contains(expected_fragment)),
+            "case {label} did not produce expected error"
+        );
+    }
+}
+
+#[test]
+fn test_calculate_batch_l2_norms_success() {
+    let batch_data = vec![3.0, 4.0, 5.0, 12.0];
+    let norms = Preprocessor::calculate_batch_l2_norms(&batch_data, 2, 2).unwrap();
+    assert_eq!(norms.len(), 2);
+    assert_relative_eq!(norms[0], 5.0); // sqrt(3^2 + 4^2) = 5
+    assert_relative_eq!(norms[1], 13.0); // sqrt(5^2 + 12^2) = 13
 }

Reply via email to