This is an automated email from the ASF dual-hosted git repository.
guanmingchiu pushed a commit to branch dev-qdp
in repository https://gitbox.apache.org/repos/asf/mahout.git
The following commit(s) were added to refs/heads/dev-qdp by this push:
new 038c9fd6d [QDP] Add pre-processing module (#673)
038c9fd6d is described below
commit 038c9fd6d5a1fc5db9dfd47a606801122220230d
Author: Guan-Ming (Wesley) Chiu <[email protected]>
AuthorDate: Mon Dec 1 23:45:16 2025 +0800
[QDP] Add pre-processing module (#673)
* add preprocessing module
* Remove unused import of Cow
---
qdp/qdp-core/src/gpu/encodings/amplitude.rs | 39 ++------------
qdp/qdp-core/src/gpu/encodings/angle.rs | 5 +-
qdp/qdp-core/src/gpu/encodings/mod.rs | 6 +++
qdp/qdp-core/src/lib.rs | 1 +
qdp/qdp-core/src/preprocessing.rs | 79 +++++++++++++++++++++++++++++
qdp/qdp-core/tests/preprocessing.rs | 73 ++++++++++++++++++++++++++
6 files changed, 166 insertions(+), 37 deletions(-)
diff --git a/qdp/qdp-core/src/gpu/encodings/amplitude.rs
b/qdp/qdp-core/src/gpu/encodings/amplitude.rs
index fb61b1d3c..1709eab77 100644
--- a/qdp/qdp-core/src/gpu/encodings/amplitude.rs
+++ b/qdp/qdp-core/src/gpu/encodings/amplitude.rs
@@ -18,7 +18,6 @@
use std::sync::Arc;
use cudarc::driver::CudaDevice;
-use rayon::prelude::*;
use crate::error::{MahoutError, Result};
use crate::gpu::memory::GpuStateVector;
use super::QuantumEncoder;
@@ -30,6 +29,8 @@ use cudarc::driver::DevicePtr;
#[cfg(target_os = "linux")]
use qdp_kernels::launch_amplitude_encode;
+use crate::preprocessing::Preprocessor;
+
/// Amplitude encoding: data → normalized quantum amplitudes
///
/// Steps: L2 norm (CPU) → GPU allocation → CUDA kernel (normalize + pad)
@@ -44,41 +45,9 @@ impl QuantumEncoder for AmplitudeEncoder {
num_qubits: usize,
) -> Result<GpuStateVector> {
// Validate qubits (max 30 = 16GB GPU memory)
- if num_qubits == 0 {
- return Err(MahoutError::InvalidInput(
- "Number of qubits must be at least 1".to_string()
- ));
- }
- if num_qubits > 30 {
- return Err(MahoutError::InvalidInput(
- format!("Number of qubits {} exceeds practical limit of 30",
num_qubits)
- ));
- }
-
- // Validate input data
- if host_data.is_empty() {
- return Err(MahoutError::InvalidInput(
- "Input data cannot be empty".to_string()
- ));
- }
-
+ Preprocessor::validate_input(host_data, num_qubits)?;
+ let norm = Preprocessor::calculate_l2_norm(host_data)?;
let state_len = 1 << num_qubits;
- if host_data.len() > state_len {
- return Err(MahoutError::InvalidInput(
- format!("Input data length {} exceeds state vector size {}",
host_data.len(), state_len)
- ));
- }
-
- // Calculate L2 norm (parallel on CPU for speed)
- let norm = {
- crate::profile_scope!("CPU::L2Norm");
- let norm_sq: f64 = host_data.par_iter().map(|x| x * x).sum();
- norm_sq.sqrt()
- };
-
- if norm == 0.0 {
- return Err(MahoutError::InvalidInput("Input data has zero
norm".to_string()));
- }
#[cfg(target_os = "linux")]
{
diff --git a/qdp/qdp-core/src/gpu/encodings/angle.rs
b/qdp/qdp-core/src/gpu/encodings/angle.rs
index b999578cf..0c2ed8c01 100644
--- a/qdp/qdp-core/src/gpu/encodings/angle.rs
+++ b/qdp/qdp-core/src/gpu/encodings/angle.rs
@@ -31,9 +31,10 @@ impl QuantumEncoder for AngleEncoder {
fn encode(
&self,
_device: &Arc<CudaDevice>,
- _data: &[f64],
- _num_qubits: usize,
+ data: &[f64],
+ num_qubits: usize,
) -> Result<GpuStateVector> {
+ self.validate_input(data, num_qubits)?;
Err(MahoutError::InvalidInput(
"Angle encoding not yet implemented. Use 'amplitude' encoding for
now.".to_string()
))
diff --git a/qdp/qdp-core/src/gpu/encodings/mod.rs
b/qdp/qdp-core/src/gpu/encodings/mod.rs
index c59b4dcb0..75cf57549 100644
--- a/qdp/qdp-core/src/gpu/encodings/mod.rs
+++ b/qdp/qdp-core/src/gpu/encodings/mod.rs
@@ -20,6 +20,7 @@ use std::sync::Arc;
use cudarc::driver::CudaDevice;
use crate::error::Result;
use crate::gpu::memory::GpuStateVector;
+use crate::preprocessing::Preprocessor;
/// Quantum encoding strategy interface
/// Implemented by: AmplitudeEncoder, AngleEncoder, BasisEncoder
@@ -32,6 +33,11 @@ pub trait QuantumEncoder: Send + Sync {
num_qubits: usize,
) -> Result<GpuStateVector>;
+ /// Validate input data before encoding
+ fn validate_input(&self, data: &[f64], num_qubits: usize) -> Result<()> {
+ Preprocessor::validate_input(data, num_qubits)
+ }
+
/// Strategy name
fn name(&self) -> &'static str;
diff --git a/qdp/qdp-core/src/lib.rs b/qdp/qdp-core/src/lib.rs
index 6634b15c4..6fa43af7e 100644
--- a/qdp/qdp-core/src/lib.rs
+++ b/qdp/qdp-core/src/lib.rs
@@ -17,6 +17,7 @@
pub mod dlpack;
pub mod gpu;
pub mod error;
+pub mod preprocessing;
#[macro_use]
mod profiling;
diff --git a/qdp/qdp-core/src/preprocessing.rs
b/qdp/qdp-core/src/preprocessing.rs
new file mode 100644
index 000000000..233548a0f
--- /dev/null
+++ b/qdp/qdp-core/src/preprocessing.rs
@@ -0,0 +1,79 @@
+//
+// Licensed to the Apache Software Foundation (ASF) under one or more
+// contributor license agreements. See the NOTICE file distributed with
+// this work for additional information regarding copyright ownership.
+// The ASF licenses this file to You under the Apache License, Version 2.0
+// (the "License"); you may not use this file except in compliance with
+// the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+use rayon::prelude::*;
+use crate::error::{MahoutError, Result};
+
+/// Shared CPU-based pre-processing pipeline for quantum encoding.
+///
+/// Centralizes validation, normalization, and data preparation steps
+/// to ensure consistency across different encoding strategies and backends.
+pub struct Preprocessor;
+
+impl Preprocessor {
+ /// Validates standard quantum input constraints.
+ ///
+ /// Checks:
+ /// - Qubit count within practical limits (1-30)
+ /// - Data availability
+ /// - Data length against state vector size
+ pub fn validate_input(host_data: &[f64], num_qubits: usize) -> Result<()> {
+ // Validate qubits (max 30 = 16GB GPU memory)
+ if num_qubits == 0 {
+ return Err(MahoutError::InvalidInput(
+ "Number of qubits must be at least 1".to_string()
+ ));
+ }
+ if num_qubits > 30 {
+ return Err(MahoutError::InvalidInput(
+ format!("Number of qubits {} exceeds practical limit of 30",
num_qubits)
+ ));
+ }
+
+ // Validate input data
+ if host_data.is_empty() {
+ return Err(MahoutError::InvalidInput(
+ "Input data cannot be empty".to_string()
+ ));
+ }
+
+ let state_len = 1 << num_qubits;
+ if host_data.len() > state_len {
+ return Err(MahoutError::InvalidInput(
+ format!("Input data length {} exceeds state vector size {}",
host_data.len(), state_len)
+ ));
+ }
+
+ Ok(())
+ }
+
+ /// Calculates L2 norm of the input data in parallel on the CPU.
+ ///
+ /// Returns error if the calculated norm is zero.
+ pub fn calculate_l2_norm(host_data: &[f64]) -> Result<f64> {
+ let norm = {
+ crate::profile_scope!("CPU::L2Norm");
+ let norm_sq: f64 = host_data.par_iter().map(|x| x * x).sum();
+ norm_sq.sqrt()
+ };
+
+ if norm == 0.0 {
+ return Err(MahoutError::InvalidInput("Input data has zero
norm".to_string()));
+ }
+
+ Ok(norm)
+ }
+}
diff --git a/qdp/qdp-core/tests/preprocessing.rs
b/qdp/qdp-core/tests/preprocessing.rs
new file mode 100644
index 000000000..011837b46
--- /dev/null
+++ b/qdp/qdp-core/tests/preprocessing.rs
@@ -0,0 +1,73 @@
+//
+// Licensed to the Apache Software Foundation (ASF) under one or more
+// contributor license agreements. See the NOTICE file distributed with
+// this work for additional information regarding copyright ownership.
+// The ASF licenses this file to You under the Apache License, Version 2.0
+// (the "License"); you may not use this file except in compliance with
+// the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+use qdp_core::preprocessing::Preprocessor;
+use qdp_core::MahoutError;
+
#[test]
fn test_validate_input_success() {
    // Data that exactly fills the 2^n state vector.
    assert!(Preprocessor::validate_input(&[1.0, 0.0], 1).is_ok());
    assert!(Preprocessor::validate_input(&[1.0, 0.0, 0.0, 0.0], 2).is_ok());

    // Data shorter than the state vector is accepted (3 values, 8 amplitudes).
    assert!(Preprocessor::validate_input(&[1.0, 0.5, 0.25], 3).is_ok());

    // The 30-qubit limit is inclusive. Only the length is compared against
    // 2^30, so a one-element slice is enough to exercise the boundary.
    assert!(Preprocessor::validate_input(&[1.0], 30).is_ok());
}
+
#[test]
fn test_validate_input_zero_qubits() {
    // The qubit count is checked first, so valid data does not rescue a
    // zero-qubit request.
    let msg = match Preprocessor::validate_input(&[1.0], 0) {
        Err(MahoutError::InvalidInput(msg)) => msg,
        _ => panic!("expected InvalidInput for zero qubits"),
    };
    assert!(msg.contains("at least 1"));
}
+
#[test]
fn test_validate_input_too_many_qubits() {
    // One qubit past the supported maximum of 30.
    let outcome = Preprocessor::validate_input(&[1.0], 31);
    let rejected = if let Err(MahoutError::InvalidInput(msg)) = outcome {
        msg.contains("exceeds practical limit")
    } else {
        false
    };
    assert!(rejected);
}
+
#[test]
fn test_validate_input_empty_data() {
    // A valid qubit count paired with no data at all.
    let empty: &[f64] = &[];
    match Preprocessor::validate_input(empty, 1) {
        Err(MahoutError::InvalidInput(msg)) => assert!(msg.contains("cannot be empty")),
        _ => panic!("expected InvalidInput for empty data"),
    }
}
+
#[test]
fn test_validate_input_data_too_large() {
    // One qubit holds 2^1 = 2 amplitudes, so three values cannot fit.
    let oversized = [1.0, 0.0, 0.0];
    let msg = match Preprocessor::validate_input(&oversized, 1) {
        Err(MahoutError::InvalidInput(msg)) => msg,
        _ => panic!("expected InvalidInput for oversized data"),
    };
    assert!(msg.contains("exceeds state vector size"));
}
+
#[test]
fn test_calculate_l2_norm_success() {
    // Each case asserts the computed norm matches the expected value to 1e-10.
    let check = |data: &[f64], expected: f64| {
        let norm = Preprocessor::calculate_l2_norm(data).unwrap();
        assert!((norm - expected).abs() < 1e-10);
    };

    // 3-4-5 right triangle.
    check(&[3.0, 4.0], 5.0);
    // Unit diagonal in two dimensions.
    check(&[1.0, 1.0], 2.0_f64.sqrt());
}
+
#[test]
fn test_calculate_l2_norm_zero() {
    // An all-zero vector cannot be normalized and must be rejected.
    match Preprocessor::calculate_l2_norm(&[0.0; 3]) {
        Err(MahoutError::InvalidInput(msg)) => assert!(msg.contains("zero norm")),
        _ => panic!("expected InvalidInput for an all-zero vector"),
    }
}