This is an automated email from the ASF dual-hosted git repository.
guanmingchiu pushed a commit to branch dev-qdp
in repository https://gitbox.apache.org/repos/asf/mahout.git
The following commit(s) were added to refs/heads/dev-qdp by this push:
new ab2b6c90a [QDP] Add tests for pre-processing (#676)
ab2b6c90a is described below
commit ab2b6c90a9c5ff7e42c698b562f613cbdec085a8
Author: Ping <[email protected]>
AuthorDate: Wed Dec 3 00:39:36 2025 +0800
[QDP] Add tests for pre-processing (#676)
Signed-off-by: 400Ping <[email protected]>
---
qdp/qdp-core/tests/preprocessing.rs | 28 ++++++++++++++++++++++++++++
1 file changed, 28 insertions(+)
diff --git a/qdp/qdp-core/tests/preprocessing.rs b/qdp/qdp-core/tests/preprocessing.rs
index 011837b46..cc7885943 100644
--- a/qdp/qdp-core/tests/preprocessing.rs
+++ b/qdp/qdp-core/tests/preprocessing.rs
@@ -54,6 +54,18 @@ fn test_validate_input_data_too_large() {
assert!(matches!(result, Err(MahoutError::InvalidInput(msg)) if
msg.contains("exceeds state vector size")));
}
+#[test]
+fn test_validate_input_allows_partial_state() {
+    // 3 qubits -> a state vector of 2^3 = 8 amplitudes; supplying fewer values (3) must be accepted.
+    let data = vec![0.5, -0.5, 0.25];
+    assert!(Preprocessor::validate_input(&data, 3).is_ok());
+}
+
+#[test]
+fn test_validate_input_max_qubits_boundary() {
+    // NOTE(review): assumes 30 is the largest qubit count validate_input accepts — confirm
+    // against the implementation; only the accepting side of the boundary is pinned here.
+    let single_amplitude = vec![1.0];
+    let outcome = Preprocessor::validate_input(&single_amplitude, 30);
+    assert!(outcome.is_ok());
+}
+
#[test]
fn test_calculate_l2_norm_success() {
let data = vec![3.0, 4.0];
@@ -71,3 +83,19 @@ fn test_calculate_l2_norm_zero() {
let result = Preprocessor::calculate_l2_norm(&data);
assert!(matches!(result, Err(MahoutError::InvalidInput(msg)) if
msg.contains("zero norm")));
}
+
+#[test]
+fn test_calculate_l2_norm_mixed_signs() {
+    // A negative component still contributes its square: sqrt(9 + 16) = 5.
+    let values = vec![-3.0, 4.0];
+    let expected = 5.0;
+    let norm = Preprocessor::calculate_l2_norm(&values).unwrap();
+    assert!((norm - expected).abs() < 1e-10);
+}
+
+#[test]
+fn test_calculate_l2_norm_matches_sequential_sum() {
+    // Squares of 1..=1000 sum to 333_833_500 (< 2^53), so every partial sum is exact in f64
+    // regardless of summation order; a mismatch means an element was dropped or double-counted.
+    let data: Vec<f64> = (1..=1000).map(|v| v as f64).collect();
+    let norm_parallel = Preprocessor::calculate_l2_norm(&data).unwrap();
+
+    // Plain sequential reference computation.
+    let norm_sequential = data.iter().map(|x| x * x).sum::<f64>().sqrt();
+
+    // Relative tolerance: the norm here is ~1.8e4, and a fixed absolute epsilon does not scale
+    // with magnitude if the input data is changed (e.g. to non-integer values).
+    let tolerance = 1e-12 * norm_sequential.max(1.0);
+    assert!((norm_parallel - norm_sequential).abs() <= tolerance);
+}