This is an automated email from the ASF dual-hosted git repository.
hcr pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/mahout.git
The following commit(s) were added to refs/heads/main by this push:
new 9e0aac162 Add unified and configurable null handling (#1101)
9e0aac162 is described below
commit 9e0aac1626d96b63b82e4f88862b08ac4fe8b609
Author: Ryan Huang <[email protected]>
AuthorDate: Tue Mar 3 07:31:48 2026 +0800
Add unified and configurable null handling (#1101)
* Add unified and configurable null handling
* Remove unused import of `Array` in I/O utilities
---
qdp/qdp-core/src/encoding/mod.rs | 3 +-
qdp/qdp-core/src/io.rs | 32 +++++-----
qdp/qdp-core/src/lib.rs | 1 +
qdp/qdp-core/src/pipeline_runner.rs | 29 +++++++--
qdp/qdp-core/src/reader.rs | 40 ++++++++++++
qdp/qdp-core/src/readers/arrow_ipc.rs | 19 +++---
qdp/qdp-core/src/readers/parquet.rs | 62 ++++++++++---------
qdp/qdp-core/tests/null_handling.rs | 79 ++++++++++++++++++++++++
qdp/qdp-python/qumat_qdp/loader.py | 12 ++++
qdp/qdp-python/src/engine.rs | 19 ++++--
qdp/qdp-python/src/loader.rs | 17 ++++-
qdp/qdp-python/tests/test_quantum_data_loader.py | 43 +++++++++++++
12 files changed, 287 insertions(+), 69 deletions(-)
diff --git a/qdp/qdp-core/src/encoding/mod.rs b/qdp/qdp-core/src/encoding/mod.rs
index 6e770ce94..a06548a9a 100644
--- a/qdp/qdp-core/src/encoding/mod.rs
+++ b/qdp/qdp-core/src/encoding/mod.rs
@@ -141,7 +141,8 @@ pub(crate) fn stream_encode<E: ChunkEncoder>(
encoder: E,
) -> Result<*mut DLManagedTensor> {
// Initialize reader
- let mut reader_core = crate::io::ParquetBlockReader::new(path, None)?;
+ let mut reader_core =
+ crate::io::ParquetBlockReader::new(path, None, crate::reader::NullHandling::FillZero)?;
let num_samples = reader_core.total_rows;
// Allocate output state vector
diff --git a/qdp/qdp-core/src/io.rs b/qdp/qdp-core/src/io.rs
index 346e6a4ab..d27564a4e 100644
--- a/qdp/qdp-core/src/io.rs
+++ b/qdp/qdp-core/src/io.rs
@@ -26,37 +26,35 @@ use std::fs::File;
use std::path::Path;
use std::sync::Arc;
-use arrow::array::{Array, ArrayRef, Float64Array, RecordBatch};
+use arrow::array::{ArrayRef, Float64Array, RecordBatch};
use arrow::datatypes::{DataType, Field, Schema};
use parquet::arrow::ArrowWriter;
use parquet::arrow::arrow_reader::ParquetRecordBatchReaderBuilder;
use parquet::file::properties::WriterProperties;
use crate::error::{MahoutError, Result};
+use crate::reader::{NullHandling, handle_float64_nulls};
/// Converts an Arrow Float64Array to Vec<f64>.
-pub fn arrow_to_vec(array: &Float64Array) -> Vec<f64> {
- if array.null_count() == 0 {
- array.values().to_vec()
- } else {
- array.iter().map(|opt| opt.unwrap_or(0.0)).collect()
- }
+pub fn arrow_to_vec(array: &Float64Array, null_handling: NullHandling) -> Result<Vec<f64>> {
+ let mut result = Vec::with_capacity(array.len());
+ handle_float64_nulls(&mut result, array, null_handling)?;
+ Ok(result)
}
/// Flattens multiple Arrow Float64Arrays into a single Vec<f64>.
-pub fn arrow_to_vec_chunked(arrays: &[Float64Array]) -> Vec<f64> {
+pub fn arrow_to_vec_chunked(
+ arrays: &[Float64Array],
+ null_handling: NullHandling,
+) -> Result<Vec<f64>> {
let total_len: usize = arrays.iter().map(|a| a.len()).sum();
let mut result = Vec::with_capacity(total_len);
for array in arrays {
- if array.null_count() == 0 {
- result.extend_from_slice(array.values());
- } else {
- result.extend(array.iter().map(|opt| opt.unwrap_or(0.0)));
- }
+ handle_float64_nulls(&mut result, array, null_handling)?;
}
- result
+ Ok(result)
}
/// Reads Float64 data from a Parquet file.
@@ -64,7 +62,7 @@ pub fn arrow_to_vec_chunked(arrays: &[Float64Array]) -> Vec<f64> {
/// Expects a single Float64 column. For zero-copy access, use [`read_parquet_to_arrow`].
pub fn read_parquet<P: AsRef<Path>>(path: P) -> Result<Vec<f64>> {
let chunks = read_parquet_to_arrow(path)?;
- Ok(arrow_to_vec_chunked(&chunks))
+ arrow_to_vec_chunked(&chunks, NullHandling::FillZero)
}
/// Writes Float64 data to a Parquet file.
@@ -228,7 +226,7 @@ pub fn write_arrow_to_parquet<P: AsRef<Path>>(
/// Add OOM protection for very large files
pub fn read_parquet_batch<P: AsRef<Path>>(path: P) -> Result<(Vec<f64>, usize, usize)> {
use crate::reader::DataReader;
- let mut reader = crate::readers::ParquetReader::new(path, None)?;
+ let mut reader = crate::readers::ParquetReader::new(path, None, NullHandling::FillZero)?;
reader.read_batch()
}
@@ -244,7 +242,7 @@ pub fn read_parquet_batch<P: AsRef<Path>>(path: P) -> Result<(Vec<f64>, usize, u
/// Add OOM protection for very large files
pub fn read_arrow_ipc_batch<P: AsRef<Path>>(path: P) -> Result<(Vec<f64>, usize, usize)> {
use crate::reader::DataReader;
- let mut reader = crate::readers::ArrowIPCReader::new(path)?;
+ let mut reader = crate::readers::ArrowIPCReader::new(path, NullHandling::FillZero)?;
reader.read_batch()
}
diff --git a/qdp/qdp-core/src/lib.rs b/qdp/qdp-core/src/lib.rs
index 65cea496e..ed9cba0b8 100644
--- a/qdp/qdp-core/src/lib.rs
+++ b/qdp/qdp-core/src/lib.rs
@@ -34,6 +34,7 @@ mod profiling;
pub use error::{MahoutError, Result, cuda_error_to_string};
pub use gpu::memory::Precision;
+pub use reader::{NullHandling, handle_float64_nulls};
// Throughput/latency pipeline runner: single path using QdpEngine and encode_batch in Rust.
#[cfg(target_os = "linux")]
diff --git a/qdp/qdp-core/src/pipeline_runner.rs b/qdp/qdp-core/src/pipeline_runner.rs
index 93d577b4b..9a41ee4bc 100644
--- a/qdp/qdp-core/src/pipeline_runner.rs
+++ b/qdp/qdp-core/src/pipeline_runner.rs
@@ -26,7 +26,7 @@ use crate::QdpEngine;
use crate::dlpack::DLManagedTensor;
use crate::error::{MahoutError, Result};
use crate::io;
-use crate::reader::StreamingDataReader;
+use crate::reader::{NullHandling, StreamingDataReader};
use crate::readers::ParquetStreamingReader;
/// Configuration for throughput/latency pipeline runs (Python run_throughput_pipeline_py).
@@ -39,6 +39,7 @@ pub struct PipelineConfig {
pub encoding_method: String,
pub seed: Option<u64>,
pub warmup_batches: usize,
+ pub null_handling: NullHandling,
}
impl Default for PipelineConfig {
@@ -51,6 +52,7 @@ impl Default for PipelineConfig {
encoding_method: "amplitude".to_string(),
seed: None,
warmup_batches: 0,
+ null_handling: NullHandling::FillZero,
}
}
}
@@ -154,12 +156,23 @@ fn path_extension_lower(path: &Path) -> Option<String> {
/// Dispatches by path extension to the appropriate io reader. Returns (data, num_samples, sample_size).
/// Unsupported or missing extension returns Err with message listing supported formats.
-fn read_file_by_extension(path: &Path) -> Result<(Vec<f64>, usize, usize)> {
+fn read_file_by_extension(
+ path: &Path,
+ null_handling: NullHandling,
+) -> Result<(Vec<f64>, usize, usize)> {
let ext_lower = path_extension_lower(path);
let ext = ext_lower.as_deref();
match ext {
- Some("parquet") => io::read_parquet_batch(path),
- Some("arrow") | Some("feather") | Some("ipc") => io::read_arrow_ipc_batch(path),
+ Some("parquet") => {
+ use crate::reader::DataReader;
+ let mut reader = crate::readers::ParquetReader::new(path, None, null_handling)?;
+ reader.read_batch()
+ }
+ Some("arrow") | Some("feather") | Some("ipc") => {
+ use crate::reader::DataReader;
+ let mut reader = crate::readers::ArrowIPCReader::new(path, null_handling)?;
+ reader.read_batch()
+ }
Some("npy") => io::read_numpy_batch(path),
Some("pt") | Some("pth") => io::read_torch_batch(path),
Some("pb") => io::read_tensorflow_batch(path),
@@ -211,7 +224,7 @@ impl PipelineIterator {
batch_limit: usize,
) -> Result<Self> {
let path = path.as_ref();
- let (data, num_samples, sample_size) = read_file_by_extension(path)?;
+ let (data, num_samples, sample_size) = read_file_by_extension(path, config.null_handling)?;
let vector_len = vector_len(config.num_qubits, &config.encoding_method);
// Dimension validation at construction.
@@ -263,7 +276,11 @@ impl PipelineIterator {
)));
}
- let mut reader = ParquetStreamingReader::new(path, Some(DEFAULT_PARQUET_ROW_GROUP_SIZE))?;
+ let mut reader = ParquetStreamingReader::new(
+ path,
+ Some(DEFAULT_PARQUET_ROW_GROUP_SIZE),
+ config.null_handling,
+ )?;
let vector_len = vector_len(config.num_qubits, &config.encoding_method);
// Read first chunk to learn sample_size; reuse as initial buffer.
diff --git a/qdp/qdp-core/src/reader.rs b/qdp/qdp-core/src/reader.rs
index 81669c036..5ea3f3486 100644
--- a/qdp/qdp-core/src/reader.rs
+++ b/qdp/qdp-core/src/reader.rs
@@ -45,8 +45,48 @@
//! }
//! ```
+use arrow::array::{Array, Float64Array};
+
use crate::error::Result;
+/// Policy for handling null values in Float64 arrays.
+#[derive(Debug, Clone, Copy, Default, PartialEq, Eq)]
+pub enum NullHandling {
+ /// Replace nulls with 0.0 (backward-compatible default).
+ #[default]
+ FillZero,
+ /// Return an error when a null is encountered.
+ Reject,
+}
+
+/// Append values from a `Float64Array` into `output`, applying the given null policy.
+///
+/// When there are no nulls the fast path copies the underlying buffer directly.
+pub fn handle_float64_nulls(
+ output: &mut Vec<f64>,
+ float_array: &Float64Array,
+ null_handling: NullHandling,
+) -> crate::error::Result<()> {
+ if float_array.null_count() == 0 {
+ output.extend_from_slice(float_array.values());
+ } else {
+ match null_handling {
+ NullHandling::FillZero => {
+ output.extend(float_array.iter().map(|opt| opt.unwrap_or(0.0)));
+ }
+ NullHandling::Reject => {
+ return Err(crate::error::MahoutError::InvalidInput(
+ "Null value encountered in Float64Array. \
+ Use NullHandling::FillZero to replace nulls with 0.0, \
+ or clean the data at the source."
+ .to_string(),
+ ));
+ }
+ }
+ }
+ Ok(())
+}
+
/// Generic data reader interface for batch quantum data.
///
/// Implementations should read data in the format:
diff --git a/qdp/qdp-core/src/readers/arrow_ipc.rs b/qdp/qdp-core/src/readers/arrow_ipc.rs
index f2c781e44..39d9d9148 100644
--- a/qdp/qdp-core/src/readers/arrow_ipc.rs
+++ b/qdp/qdp-core/src/readers/arrow_ipc.rs
@@ -24,12 +24,13 @@ use arrow::datatypes::DataType;
use arrow::ipc::reader::FileReader as ArrowFileReader;
use crate::error::{MahoutError, Result};
-use crate::reader::DataReader;
+use crate::reader::{DataReader, NullHandling, handle_float64_nulls};
/// Reader for Arrow IPC files containing FixedSizeList<Float64> or List<Float64> columns.
pub struct ArrowIPCReader {
path: std::path::PathBuf,
read: bool,
+ null_handling: NullHandling,
}
impl ArrowIPCReader {
@@ -37,7 +38,8 @@ impl ArrowIPCReader {
///
/// # Arguments
/// * `path` - Path to the Arrow IPC file (.arrow or .feather)
- pub fn new<P: AsRef<Path>>(path: P) -> Result<Self> {
+ /// * `null_handling` - Policy for null values (defaults to `FillZero`)
+ pub fn new<P: AsRef<Path>>(path: P, null_handling: NullHandling) -> Result<Self> {
let path = path.as_ref();
// Verify file exists
@@ -64,6 +66,7 @@ impl ArrowIPCReader {
Ok(Self {
path: path.to_path_buf(),
read: false,
+ null_handling,
})
}
}
@@ -136,11 +139,7 @@ impl DataReader for ArrowIPCReader {
.downcast_ref::<Float64Array>()
.ok_or_else(|| MahoutError::Io("Values must be Float64".to_string()))?;
- if float_array.null_count() == 0 {
- all_data.extend_from_slice(float_array.values());
- } else {
- all_data.extend(float_array.iter().map(|opt| opt.unwrap_or(0.0)));
- }
+ handle_float64_nulls(&mut all_data, float_array, self.null_handling)?;
num_samples += list_array.len();
}
@@ -182,11 +181,7 @@ impl DataReader for ArrowIPCReader {
all_data.reserve(new_capacity);
}
- if float_array.null_count() == 0 {
- all_data.extend_from_slice(float_array.values());
- } else {
- all_data.extend(float_array.iter().map(|opt| opt.unwrap_or(0.0)));
- }
+ handle_float64_nulls(&mut all_data, float_array, self.null_handling)?;
num_samples += 1;
}
diff --git a/qdp/qdp-core/src/readers/parquet.rs b/qdp/qdp-core/src/readers/parquet.rs
index 91bb0007e..2bbda38a7 100644
--- a/qdp/qdp-core/src/readers/parquet.rs
+++ b/qdp/qdp-core/src/readers/parquet.rs
@@ -24,13 +24,14 @@ use arrow::datatypes::DataType;
use parquet::arrow::arrow_reader::ParquetRecordBatchReaderBuilder;
use crate::error::{MahoutError, Result};
-use crate::reader::{DataReader, StreamingDataReader};
+use crate::reader::{DataReader, NullHandling, StreamingDataReader, handle_float64_nulls};
/// Reader for Parquet files containing List<Float64> or FixedSizeList<Float64> columns.
pub struct ParquetReader {
reader: Option<parquet::arrow::arrow_reader::ParquetRecordBatchReader>,
sample_size: Option<usize>,
total_rows: usize,
+ null_handling: NullHandling,
}
impl ParquetReader {
@@ -39,7 +40,12 @@ impl ParquetReader {
/// # Arguments
/// * `path` - Path to the Parquet file
/// * `batch_size` - Optional batch size for reading (defaults to entire file)
- pub fn new<P: AsRef<Path>>(path: P, batch_size: Option<usize>) -> Result<Self> {
+ /// * `null_handling` - Policy for null values (defaults to `FillZero`)
+ pub fn new<P: AsRef<Path>>(
+ path: P,
+ batch_size: Option<usize>,
+ null_handling: NullHandling,
+ ) -> Result<Self> {
let path = path.as_ref();
// Verify file exists
@@ -118,6 +124,7 @@ impl ParquetReader {
reader: Some(reader),
sample_size: None,
total_rows,
+ null_handling,
})
}
}
@@ -173,11 +180,7 @@ impl DataReader for ParquetReader {
all_data.reserve(current_size * self.total_rows);
}
- if float_array.null_count() == 0 {
- all_data.extend_from_slice(float_array.values());
- } else {
- all_data.extend(float_array.iter().map(|opt| opt.unwrap_or(0.0)));
- }
+ handle_float64_nulls(&mut all_data, float_array, self.null_handling)?;
num_samples += 1;
}
@@ -203,11 +206,7 @@ impl DataReader for ParquetReader {
.downcast_ref::<Float64Array>()
.ok_or_else(|| MahoutError::Io("Values must be Float64".to_string()))?;
- if float_array.null_count() == 0 {
- all_data.extend_from_slice(float_array.values());
- } else {
- all_data.extend(float_array.iter().map(|opt| opt.unwrap_or(0.0)));
- }
+ handle_float64_nulls(&mut all_data, float_array, self.null_handling)?;
num_samples += list_array.len();
}
@@ -247,6 +246,7 @@ pub struct ParquetStreamingReader {
leftover_data: Vec<f64>,
leftover_cursor: usize,
pub total_rows: usize,
+ null_handling: NullHandling,
}
impl ParquetStreamingReader {
@@ -255,7 +255,12 @@ impl ParquetStreamingReader {
/// # Arguments
/// * `path` - Path to the Parquet file
/// * `batch_size` - Optional batch size (defaults to 2048)
- pub fn new<P: AsRef<Path>>(path: P, batch_size: Option<usize>) -> Result<Self> {
+ /// * `null_handling` - Policy for null values (defaults to `FillZero`)
+ pub fn new<P: AsRef<Path>>(
+ path: P,
+ batch_size: Option<usize>,
+ null_handling: NullHandling,
+ ) -> Result<Self> {
let path = path.as_ref();
// Verify file exists
@@ -338,6 +343,7 @@ impl ParquetStreamingReader {
leftover_data: Vec::new(),
leftover_cursor: 0,
total_rows,
+ null_handling,
})
}
@@ -449,11 +455,11 @@ impl StreamingDataReader for ParquetStreamingReader {
current_sample_size =
Some(float_array.len());
}
- if float_array.null_count() == 0 {
- batch_values.extend_from_slice(float_array.values());
- } else {
- return Err(MahoutError::Io("Null value encountered in Float64Array during quantum encoding. Please check data quality at the source.".to_string()));
- }
+ handle_float64_nulls(
+ &mut batch_values,
+ float_array,
+ self.null_handling,
+ )?;
}
(
@@ -489,11 +495,11 @@ impl StreamingDataReader for ParquetStreamingReader {
})?;
let mut batch_values = Vec::new();
- if float_array.null_count() == 0 {
- batch_values.extend_from_slice(float_array.values());
- } else {
- return Err(MahoutError::Io("Null value encountered in Float64Array during quantum encoding. Please check data quality at the source.".to_string()));
- }
+ handle_float64_nulls(
+ &mut batch_values,
+ float_array,
+ self.null_handling,
+ )?;
(current_sample_size, batch_values)
}
@@ -515,11 +521,11 @@ impl StreamingDataReader for ParquetStreamingReader {
let current_sample_size = 1;
let mut batch_values = Vec::new();
- if float_array.null_count() == 0 {
- batch_values.extend_from_slice(float_array.values());
- } else {
- return Err(MahoutError::Io("Null value encountered in Float64Array during quantum encoding. Please check data quality at the source.".to_string()));
- }
+ handle_float64_nulls(
+ &mut batch_values,
+ float_array,
+ self.null_handling,
+ )?;
(current_sample_size, batch_values)
}
diff --git a/qdp/qdp-core/tests/null_handling.rs b/qdp/qdp-core/tests/null_handling.rs
new file mode 100644
index 000000000..6d61df410
--- /dev/null
+++ b/qdp/qdp-core/tests/null_handling.rs
@@ -0,0 +1,79 @@
+//
+// Licensed to the Apache Software Foundation (ASF) under one or more
+// contributor license agreements. See the NOTICE file distributed with
+// this work for additional information regarding copyright ownership.
+// The ASF licenses this file to You under the Apache License, Version 2.0
+// (the "License"); you may not use this file except in compliance with
+// the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+//! Tests for the unified NullHandling policy.
+
+use arrow::array::Float64Array;
+use qdp_core::reader::{NullHandling, handle_float64_nulls};
+
+#[test]
+fn fill_zero_replaces_nulls() {
+ let array = Float64Array::from(vec![Some(1.0), None, Some(3.0), None]);
+ let mut output = Vec::new();
+ handle_float64_nulls(&mut output, &array, NullHandling::FillZero).unwrap();
+ assert_eq!(output, vec![1.0, 0.0, 3.0, 0.0]);
+}
+
+#[test]
+fn reject_returns_error_on_null() {
+ let array = Float64Array::from(vec![Some(1.0), None, Some(3.0)]);
+ let mut output = Vec::new();
+ let result = handle_float64_nulls(&mut output, &array, NullHandling::Reject);
+ assert!(result.is_err());
+ let err_msg = result.unwrap_err().to_string();
+ assert!(
+ err_msg.contains("Null value encountered"),
+ "unexpected error: {}",
+ err_msg
+ );
+}
+
+#[test]
+fn no_nulls_fast_path() {
+ let array = Float64Array::from(vec![1.0, 2.0, 3.0, 4.0]);
+ let mut output = Vec::new();
+
+ // Both policies should succeed and produce the same result when no nulls present
+ handle_float64_nulls(&mut output, &array, NullHandling::FillZero).unwrap();
+ assert_eq!(output, vec![1.0, 2.0, 3.0, 4.0]);
+
+ let mut output2 = Vec::new();
+ handle_float64_nulls(&mut output2, &array, NullHandling::Reject).unwrap();
+ assert_eq!(output2, vec![1.0, 2.0, 3.0, 4.0]);
+}
+
+#[test]
+fn default_is_fill_zero() {
+ assert_eq!(NullHandling::default(), NullHandling::FillZero);
+}
+
+#[test]
+fn fill_zero_on_all_nulls() {
+ let array = Float64Array::from(vec![None, None, None]);
+ let mut output = Vec::new();
+ handle_float64_nulls(&mut output, &array, NullHandling::FillZero).unwrap();
+ assert_eq!(output, vec![0.0, 0.0, 0.0]);
+}
+
+#[test]
+fn empty_array_is_noop() {
+ let array = Float64Array::from(Vec::<f64>::new());
+ let mut output = Vec::new();
+ handle_float64_nulls(&mut output, &array, NullHandling::FillZero).unwrap();
+ assert!(output.is_empty());
+ handle_float64_nulls(&mut output, &array, NullHandling::Reject).unwrap();
+ assert!(output.is_empty());
+}
diff --git a/qdp/qdp-python/qumat_qdp/loader.py b/qdp/qdp-python/qumat_qdp/loader.py
index 6c908325d..29d01863b 100644
--- a/qdp/qdp-python/qumat_qdp/loader.py
+++ b/qdp/qdp-python/qumat_qdp/loader.py
@@ -118,6 +118,7 @@ class QuantumDataLoader:
)
self._synthetic_requested = False # set True only by source_synthetic()
self._file_requested = False
+ self._null_handling: Optional[str] = None
def qubits(self, n: int) -> QuantumDataLoader:
"""Set number of qubits. Returns self for chaining."""
@@ -190,6 +191,15 @@ class QuantumDataLoader:
self._seed = s
return self
+ def null_handling(self, policy: str) -> QuantumDataLoader:
+ """Set null handling policy ('fill_zero' or 'reject'). Returns self for chaining."""
+ if policy not in ("fill_zero", "reject"):
+ raise ValueError(
+ f"null_handling must be 'fill_zero' or 'reject', got {policy!r}"
+ )
+ self._null_handling = policy
+ return self
+
def _create_iterator(self) -> Iterator[object]:
"""Build engine and return the Rust-backed loader iterator (synthetic or file)."""
if self._synthetic_requested and self._file_requested:
@@ -237,6 +247,7 @@ class QuantumDataLoader:
num_qubits=self._num_qubits,
encoding_method=self._encoding_method,
batch_limit=None,
+ null_handling=self._null_handling,
)
)
create_synthetic_loader = getattr(engine, "create_synthetic_loader", None)
@@ -251,6 +262,7 @@ class QuantumDataLoader:
num_qubits=self._num_qubits,
encoding_method=self._encoding_method,
seed=self._seed,
+ null_handling=self._null_handling,
)
)
diff --git a/qdp/qdp-python/src/engine.rs b/qdp/qdp-python/src/engine.rs
index 29f56909a..2a768bdb2 100644
--- a/qdp/qdp-python/src/engine.rs
+++ b/qdp/qdp-python/src/engine.rs
@@ -26,7 +26,7 @@ use pyo3::prelude::*;
use qdp_core::{Precision, QdpEngine as CoreEngine};
#[cfg(target_os = "linux")]
-use crate::loader::{PyQuantumLoader, config_from_args, path_from_py};
+use crate::loader::{PyQuantumLoader, config_from_args, parse_null_handling, path_from_py};
/// PyO3 wrapper for QdpEngine
///
@@ -575,7 +575,7 @@ impl QdpEngine {
// --- Loader factory methods (Linux only) ---
#[cfg(target_os = "linux")]
/// Create a synthetic-data pipeline iterator (for QuantumDataLoader.source_synthetic()).
- #[pyo3(signature = (total_batches, batch_size, num_qubits, encoding_method, seed=None))]
+ #[pyo3(signature = (total_batches, batch_size, num_qubits, encoding_method, seed=None, null_handling=None))]
fn create_synthetic_loader(
&self,
total_batches: usize,
@@ -583,7 +583,9 @@ impl QdpEngine {
num_qubits: u32,
encoding_method: &str,
seed: Option<u64>,
+ null_handling: Option<&str>,
) -> PyResult<PyQuantumLoader> {
+ let nh = parse_null_handling(null_handling)?;
let config = config_from_args(
&self.engine,
batch_size,
@@ -591,6 +593,7 @@ impl QdpEngine {
encoding_method,
total_batches,
seed,
+ nh,
);
let iter =
qdp_core::PipelineIterator::new_synthetic(self.engine.clone(), config).map_err(
|e| PyRuntimeError::new_err(format!("create_synthetic_loader failed: {}", e)),
@@ -600,7 +603,8 @@ impl QdpEngine {
#[cfg(target_os = "linux")]
/// Create a file-backed pipeline iterator (full read then batch; for QuantumDataLoader.source_file(path)).
- #[pyo3(signature = (path, batch_size, num_qubits, encoding_method, batch_limit=None))]
+ #[allow(clippy::too_many_arguments)]
+ #[pyo3(signature = (path, batch_size, num_qubits, encoding_method, batch_limit=None, null_handling=None))]
fn create_file_loader(
&self,
py: Python<'_>,
@@ -609,9 +613,11 @@ impl QdpEngine {
num_qubits: u32,
encoding_method: &str,
batch_limit: Option<usize>,
+ null_handling: Option<&str>,
) -> PyResult<PyQuantumLoader> {
let path_str = path_from_py(path)?;
let batch_limit = batch_limit.unwrap_or(usize::MAX);
+ let nh = parse_null_handling(null_handling)?;
let config = config_from_args(
&self.engine,
batch_size,
@@ -619,6 +625,7 @@ impl QdpEngine {
encoding_method,
0,
None,
+ nh,
);
let engine = self.engine.clone();
let iter = py
@@ -636,7 +643,8 @@ impl QdpEngine {
#[cfg(target_os = "linux")]
/// Create a streaming Parquet pipeline iterator (for QuantumDataLoader.source_file(path, streaming=True)).
- #[pyo3(signature = (path, batch_size, num_qubits, encoding_method, batch_limit=None))]
+ #[allow(clippy::too_many_arguments)]
+ #[pyo3(signature = (path, batch_size, num_qubits, encoding_method, batch_limit=None, null_handling=None))]
fn create_streaming_file_loader(
&self,
py: Python<'_>,
@@ -645,9 +653,11 @@ impl QdpEngine {
num_qubits: u32,
encoding_method: &str,
batch_limit: Option<usize>,
+ null_handling: Option<&str>,
) -> PyResult<PyQuantumLoader> {
let path_str = path_from_py(path)?;
let batch_limit = batch_limit.unwrap_or(usize::MAX);
+ let nh = parse_null_handling(null_handling)?;
let config = config_from_args(
&self.engine,
batch_size,
@@ -655,6 +665,7 @@ impl QdpEngine {
encoding_method,
0,
None,
+ nh,
);
let engine = self.engine.clone();
let iter = py
diff --git a/qdp/qdp-python/src/loader.rs b/qdp/qdp-python/src/loader.rs
index 2806b17af..87749d688 100644
--- a/qdp/qdp-python/src/loader.rs
+++ b/qdp/qdp-python/src/loader.rs
@@ -20,6 +20,7 @@ mod loader_impl {
use crate::tensor::QuantumTensor;
use pyo3::exceptions::PyRuntimeError;
use pyo3::prelude::*;
+ use qdp_core::reader::NullHandling;
use qdp_core::{PipelineConfig, PipelineIterator, QdpEngine as CoreEngine};
/// Rust-backed iterator yielding one QuantumTensor per batch; used by QuantumDataLoader.
@@ -70,6 +71,18 @@ mod loader_impl {
}
}
+ /// Parse a Python null-handling string into the Rust enum.
+ pub fn parse_null_handling(s: Option<&str>) -> PyResult<NullHandling> {
+ match s {
+ None | Some("fill_zero") => Ok(NullHandling::FillZero),
+ Some("reject") => Ok(NullHandling::Reject),
+ Some(other) => Err(pyo3::exceptions::PyValueError::new_err(format!(
+ "Invalid null_handling policy '{}'. Expected 'fill_zero' or 'reject'.",
+ other
+ ))),
+ }
+ }
+
/// Build PipelineConfig from Python args. device_id is 0 (engine does not expose it); iterator uses engine clone with correct device.
pub fn config_from_args(
_engine: &CoreEngine,
@@ -78,6 +91,7 @@ mod loader_impl {
encoding_method: &str,
total_batches: usize,
seed: Option<u64>,
+ null_handling: NullHandling,
) -> PipelineConfig {
PipelineConfig {
device_id: 0,
@@ -87,6 +101,7 @@ mod loader_impl {
encoding_method: encoding_method.to_string(),
seed,
warmup_batches: 0,
+ null_handling,
}
}
@@ -100,4 +115,4 @@ mod loader_impl {
}
#[cfg(target_os = "linux")]
-pub use loader_impl::{PyQuantumLoader, config_from_args, path_from_py};
+pub use loader_impl::{PyQuantumLoader, config_from_args, parse_null_handling, path_from_py};
diff --git a/qdp/qdp-python/tests/test_quantum_data_loader.py b/qdp/qdp-python/tests/test_quantum_data_loader.py
index 5d5fb2005..e636489ed 100644
--- a/qdp/qdp-python/tests/test_quantum_data_loader.py
+++ b/qdp/qdp-python/tests/test_quantum_data_loader.py
@@ -141,3 +141,46 @@ def test_streaming_parquet_extension_ok():
# Iteration may raise RuntimeError (no CUDA) or fail on missing file; we only check builder accepts.
assert loader._streaming_requested is True
assert loader._file_path == "/tmp/data.parquet"
+
+
+# --- NullHandling builder tests ---
+
+
+@pytest.mark.skipif(not _loader_available(), reason="QuantumDataLoader not available")
+def test_null_handling_fill_zero():
+ """null_handling('fill_zero') sets the field correctly."""
+ loader = (
+ QuantumDataLoader(device_id=0)
+ .qubits(4)
+ .batches(10, size=4)
+ .null_handling("fill_zero")
+ )
+ assert loader._null_handling == "fill_zero"
+
+
+@pytest.mark.skipif(not _loader_available(), reason="QuantumDataLoader not available")
+def test_null_handling_reject():
+ """null_handling('reject') sets the field correctly."""
+ loader = (
+ QuantumDataLoader(device_id=0)
+ .qubits(4)
+ .batches(10, size=4)
+ .null_handling("reject")
+ )
+ assert loader._null_handling == "reject"
+
+
+@pytest.mark.skipif(not _loader_available(), reason="QuantumDataLoader not available")
+def test_null_handling_invalid_raises():
+ """null_handling with an invalid string raises ValueError."""
+ with pytest.raises(ValueError) as exc_info:
+ QuantumDataLoader(device_id=0).null_handling("invalid_policy")
+ msg = str(exc_info.value)
+ assert "fill_zero" in msg or "reject" in msg
+
+
+@pytest.mark.skipif(not _loader_available(), reason="QuantumDataLoader not available")
+def test_null_handling_default_is_none():
+ """By default, _null_handling is None (Rust will use FillZero)."""
+ loader = QuantumDataLoader(device_id=0)
+ assert loader._null_handling is None