rok commented on code in PR #8162:
URL: https://github.com/apache/arrow-rs/pull/8162#discussion_r2313399562
##########
parquet/tests/encryption/encryption_async.rs:
##########
@@ -493,8 +503,341 @@ async fn read_and_roundtrip_to_encrypted_file_async(
     verify_encryption_test_file_read_async(&mut file, decryption_properties).await
 }
+// Type aliases for multithreaded file writing tests
+type ColSender = Sender<ArrowLeafColumn>;
+type ColumnWriterTask = JoinHandle<Result<ArrowColumnWriter, ParquetError>>;
+type RBStreamSerializeResult = Result<(Vec<ArrowColumnChunk>, usize), ParquetError>;
+
+async fn send_arrays_to_column_writers(
+    col_array_channels: &[ColSender],
+    rb: &RecordBatch,
+    schema: &Arc<Schema>,
+) -> Result<(), ParquetError> {
+    // Each leaf column has its own channel, increment next_channel for each leaf column sent.
+    let mut next_channel = 0;
+    for (array, field) in rb.columns().iter().zip(schema.fields()) {
+        for c in compute_leaves(field, array)? {
+            if col_array_channels[next_channel].send(c).await.is_err() {
+                return Ok(());
+            }
+            next_channel += 1;
+        }
+    }
+    Ok(())
+}
+
+/// Spawns a tokio task which joins the parallel column writer tasks,
+/// and finalizes the row group
+fn spawn_rg_join_and_finalize_task(
+    column_writer_tasks: Vec<ColumnWriterTask>,
+    rg_rows: usize,
+) -> JoinHandle<RBStreamSerializeResult> {
+    tokio::task::spawn(async move {
+        let num_cols = column_writer_tasks.len();
+        let mut finalized_rg = Vec::with_capacity(num_cols);
+        for task in column_writer_tasks.into_iter() {
+            let writer = task
+                .await
+                .map_err(|e| ParquetError::General(e.to_string()))??;
+            finalized_rg.push(writer.close()?);
+        }
+        Ok((finalized_rg, rg_rows))
+    })
+}
+
+fn spawn_parquet_parallel_serialization_task(
+    writer_factory: ArrowRowGroupWriterFactory,
+    mut data: Receiver<RecordBatch>,
+    serialize_tx: Sender<JoinHandle<RBStreamSerializeResult>>,
+    schema: Arc<Schema>,
+) -> JoinHandle<Result<(), ParquetError>> {
+    tokio::spawn(async move {
+        let max_buffer_rb = 10;
+        let max_row_group_rows = 10;
+        let mut row_group_index = 0;
+
+        let column_writers = writer_factory.create_column_writers(row_group_index)?;
+
+        let (mut col_writer_tasks, mut col_array_channels) =
+            spawn_column_parallel_row_group_writer(column_writers, max_buffer_rb)?;
+
+        let mut current_rg_rows = 0;
+
+        while let Some(mut rb) = data.recv().await {
+            // This loop allows the "else" block to repeatedly split the RecordBatch to handle the case
+            // when max_row_group_rows < execution.batch_size as an alternative to a recursive async
+            // function.
+            loop {
+                if current_rg_rows + rb.num_rows() < max_row_group_rows {
+                    send_arrays_to_column_writers(&col_array_channels, &rb, &schema).await?;
+                    current_rg_rows += rb.num_rows();
+                    break;
+                } else {
+                    let rows_left = max_row_group_rows - current_rg_rows;
+                    let rb_split = rb.slice(0, rows_left);
+                    send_arrays_to_column_writers(&col_array_channels, &rb_split, &schema).await?;
+
+                    // Signal the parallel column writers that the RowGroup is done, join and finalize RowGroup
+                    // on a separate task, so that we can immediately start on the next RG before waiting
+                    // for the current one to finish.
+                    drop(col_array_channels);
+
+                    let finalize_rg_task =
+                        spawn_rg_join_and_finalize_task(col_writer_tasks, max_row_group_rows);
+
+                    // Do not surface error from closed channel (means something
+                    // else hit an error, and the plan is shutting down).
+                    if serialize_tx.send(finalize_rg_task).await.is_err() {
+                        return Ok(());
+                    }
+
+                    current_rg_rows = 0;
+                    rb = rb.slice(rows_left, rb.num_rows() - rows_left);
+
+                    row_group_index += 1;
+                    let column_writers = writer_factory.create_column_writers(row_group_index)?;
+                    (col_writer_tasks, col_array_channels) =
+                        spawn_column_parallel_row_group_writer(column_writers, 100)?;
+                }
+            }
+        }
+
+        drop(col_array_channels);
+        // Handle leftover rows as final row group, which may be smaller than max_row_group_rows
+        if current_rg_rows > 0 {
+            let finalize_rg_task =
+                spawn_rg_join_and_finalize_task(col_writer_tasks, current_rg_rows);
+
+            // Do not surface error from closed channel (means something
+            // else hit an error, and the plan is shutting down).
+            if serialize_tx.send(finalize_rg_task).await.is_err() {
+                return Ok(());
+            }
+        }
+
+        Ok(())
+    })
+}
+
+fn spawn_column_parallel_row_group_writer(
+    col_writers: Vec<ArrowColumnWriter>,
+    max_buffer_size: usize,
+) -> Result<(Vec<ColumnWriterTask>, Vec<ColSender>), ParquetError> {
+    let num_columns = col_writers.len();
+
+    let mut col_writer_tasks = Vec::with_capacity(num_columns);
+    let mut col_array_channels = Vec::with_capacity(num_columns);
+    for mut col_writer in col_writers.into_iter() {
+        let (send_array, mut receive_array) =
+            tokio::sync::mpsc::channel::<ArrowLeafColumn>(max_buffer_size);
+        col_array_channels.push(send_array);
+        let handle = tokio::spawn(async move {
+            while let Some(col) = receive_array.recv().await {
+                col_writer.write(&col)?;
+            }
+            Ok(col_writer)
+        });
+        col_writer_tasks.push(handle);
+    }
+    Ok((col_writer_tasks, col_array_channels))
+}
+
+/// Consume RowGroups serialized by other parallel tasks and concatenate them
+/// to the final parquet file
+async fn concatenate_parallel_row_groups<W: Write + Send>(
+    mut parquet_writer: SerializedFileWriter<W>,
+    mut serialize_rx: Receiver<JoinHandle<RBStreamSerializeResult>>,
+) -> Result<FileMetaData, ParquetError> {
+    while let Some(task) = serialize_rx.recv().await {
+        let result = task.await;
+        let mut rg_out = parquet_writer.next_row_group()?;
+        let (serialized_columns, _cnt) =
+            result.map_err(|e| ParquetError::General(e.to_string()))??;
+
+        for column_chunk in serialized_columns {
+            column_chunk.append_to_row_group(&mut rg_out)?;
+        }
+        rg_out.close()?;
+    }
+
+    let file_metadata = parquet_writer.close()?;
+    Ok(file_metadata)
+}
+
+// This test is based on DataFusion's ParquetSink. Motivation is to test
+// concurrent writing of encrypted data over multiple row groups using the low-level API.
+#[tokio::test]
+async fn test_concurrent_encrypted_writing_over_multiple_row_groups() {
+    // Read example data and set up encryption/decryption properties
+    let testdata = arrow::util::test_util::parquet_test_data();
+    let path = format!("{testdata}/encrypt_columns_and_footer.parquet.encrypted");
+    let file = std::fs::File::open(path).unwrap();
+
+    let file_encryption_properties = FileEncryptionProperties::builder(b"0123456789012345".into())
+        .with_column_key("double_field", b"1234567890123450".into())
+        .with_column_key("float_field", b"1234567890123451".into())
+        .build()
+        .unwrap();
+    let decryption_properties = FileDecryptionProperties::builder(b"0123456789012345".into())
+        .with_column_key("double_field", b"1234567890123450".into())
+        .with_column_key("float_field", b"1234567890123451".into())
+        .build()
+        .unwrap();
+
+    let (record_batches, metadata) =
+        read_encrypted_file(&file, decryption_properties.clone()).unwrap();
+    let schema = metadata.schema();
+
+    // Create a channel to send RecordBatches to the writer, then send the record batches
+    let (record_batch_tx, data) = tokio::sync::mpsc::channel::<RecordBatch>(100);
+    let data_generator = tokio::spawn(async move {
+        for record_batch in record_batches {
+            record_batch_tx.send(record_batch).await.unwrap();
+        }
+    });
+
+    let props = Some(
+        WriterPropertiesBuilder::default()
+            .with_file_encryption_properties(file_encryption_properties)
+            .build(),
+    );
+
+    // Create a temporary file to write the encrypted data
+    let temp_file = tempfile::tempfile().unwrap();
+    let arrow_writer =
+        ArrowWriter::try_new(&temp_file, metadata.schema().clone(), props.clone()).unwrap();
+
+    let (writer, row_group_writer_factory) = arrow_writer.into_serialized_writer().unwrap();
+    let max_row_groups = 1;
+
+    let (serialize_tx, serialize_rx) =
+        tokio::sync::mpsc::channel::<JoinHandle<RBStreamSerializeResult>>(max_row_groups);
+
+    let launch_serialization_task = spawn_parquet_parallel_serialization_task(
+        row_group_writer_factory,
+        data,
+        serialize_tx,
+        schema.clone(),
+    );
+
+    let _file_metadata = concatenate_parallel_row_groups(writer, serialize_rx)
+        .await
+        .unwrap();
+
+    data_generator.await.unwrap();
+    launch_serialization_task.await.unwrap().unwrap();
+
+    // Check that the file was written correctly
+    let (read_record_batches, read_metadata) =
+        read_encrypted_file(&temp_file, decryption_properties.clone()).unwrap();
+
+    assert_eq!(read_metadata.metadata().file_metadata().num_rows(), 50);
+    verify_encryption_test_data(read_record_batches, read_metadata.metadata());
+}
+
 #[tokio::test]
-async fn test_multi_threaded_encrypted_writing() {
+async fn test_multi_threaded_encrypted_writing_replace_deprecated_api() {
+    // Read example data and set up encryption/decryption properties
+    let testdata = arrow::util::test_util::parquet_test_data();
+    let path = format!("{testdata}/encrypt_columns_and_footer.parquet.encrypted");
+    let file = std::fs::File::open(path).unwrap();
+
+    let file_encryption_properties = FileEncryptionProperties::builder(b"0123456789012345".into())
+        .with_column_key("double_field", b"1234567890123450".into())
+        .with_column_key("float_field", b"1234567890123451".into())
+        .build()
+        .unwrap();
+    let decryption_properties = FileDecryptionProperties::builder(b"0123456789012345".into())
+        .with_column_key("double_field", b"1234567890123450".into())
+        .with_column_key("float_field", b"1234567890123451".into())
+        .build()
+        .unwrap();
+
+    let (record_batches, metadata) =
+        read_encrypted_file(&file, decryption_properties.clone()).unwrap();
+    let schema = metadata.schema().clone();
+
+    let props = Some(
+        WriterPropertiesBuilder::default()
+            .with_file_encryption_properties(file_encryption_properties)
+            .build(),
+    );
+
+    // Create a temporary file to write the encrypted data
+    let temp_file = tempfile::tempfile().unwrap();
+    let writer =
+        ArrowWriter::try_new(&temp_file, metadata.schema().clone(), props.clone()).unwrap();
+
+    let (mut serialized_file_writer, row_group_writer_factory) =
+        writer.into_serialized_writer().unwrap();
+
+    let (serialize_tx, mut serialize_rx) =
+        tokio::sync::mpsc::channel::<JoinHandle<RBStreamSerializeResult>>(1);
+
+    // Create a channel to send RecordBatches to the writer, then send the record batches
+    let (record_batch_tx, mut data) = tokio::sync::mpsc::channel::<RecordBatch>(100);
+    let data_generator = tokio::spawn(async move {
+        for record_batch in record_batches {
+            record_batch_tx.send(record_batch).await.unwrap();
+        }
+    });
+
+    // Get column writers
+    // This is instead of let col_writers = writer.get_column_writers().unwrap();

Review Comment:

Ok, removing this comment too:

```rust
// This is instead of arrow_writer.append_row_group(arrow_column_chunks)
```
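For readers skimming the thread, here is a minimal, single-threaded sketch of the pattern the removed comments refer to: writing one row group via `into_serialized_writer` and appending the finished `ArrowColumnChunk`s, instead of the deprecated `arrow_writer.append_row_group(...)`. It only uses the calls already exercised in the diff above; the function name, the single `Int32` column, and the omission of encryption properties are illustrative assumptions, not part of the PR.

```rust
use std::sync::Arc;

use arrow::array::Int32Array;
use arrow::datatypes::{DataType, Field, Schema};
use arrow::record_batch::RecordBatch;
use parquet::arrow::arrow_writer::{compute_leaves, ArrowWriter};

// Hypothetical helper, not part of the PR: writes one in-memory batch as one row group.
fn write_one_row_group_with_new_api() {
    // Illustrative schema/data; the tests above read theirs from the encrypted test file.
    let schema = Arc::new(Schema::new(vec![Field::new("x", DataType::Int32, false)]));
    let batch = RecordBatch::try_new(
        schema.clone(),
        vec![Arc::new(Int32Array::from(vec![1, 2, 3]))],
    )
    .unwrap();

    // The tests above pass WriterProperties carrying FileEncryptionProperties instead of None.
    let file = tempfile::tempfile().unwrap();
    let writer = ArrowWriter::try_new(&file, schema.clone(), None).unwrap();

    // Split the ArrowWriter into a SerializedFileWriter plus a factory that
    // hands out the column writers for each row group.
    let (mut file_writer, factory) = writer.into_serialized_writer().unwrap();
    let mut column_writers = factory.create_column_writers(0).unwrap();

    // Encode each leaf column (single-threaded here; the tests fan this out
    // across tokio tasks). One leaf per column because the schema is flat.
    for ((array, field), col_writer) in batch
        .columns()
        .iter()
        .zip(schema.fields())
        .zip(column_writers.iter_mut())
    {
        for leaf in compute_leaves(field, array).unwrap() {
            col_writer.write(&leaf).unwrap();
        }
    }

    // This step replaces the deprecated `arrow_writer.append_row_group(chunks)`:
    // close each column writer into an ArrowColumnChunk and append it to an
    // explicitly created row group.
    let mut rg_out = file_writer.next_row_group().unwrap();
    for col_writer in column_writers {
        let chunk = col_writer.close().unwrap();
        chunk.append_to_row_group(&mut rg_out).unwrap();
    }
    rg_out.close().unwrap();
    file_writer.close().unwrap();
}
```

The multi-row-group tests in the diff repeat exactly this closing sequence per row group, calling `create_column_writers(row_group_index)` for each new group and joining the per-column tasks before appending.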