This is an automated email from the ASF dual-hosted git repository.

kevinjqliu pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/iceberg-python.git


The following commit(s) were added to refs/heads/main by this push:
     new 5773b7f1 perf: iterate over generators when writing datafiles to reduce memory pressure (#2671)
5773b7f1 is described below

commit 5773b7f1bf2081a90a490f9d670eef804eb88ab4
Author: Alex <[email protected]>
AuthorDate: Mon Nov 3 11:46:49 2025 -0700

    perf: iterate over generators when writing datafiles to reduce memory pressure (#2671)
    
    # Rationale for this change
    
    When writing to partitioned tables, there is a large memory spike when
    the partitions are computed, because we call `.combine_chunks()` on the
    new partitioned arrow tables and materialize the entire list of
    partitions before writing data files.
    
    This PR switches the partition computation to a generator to avoid
    materializing all the partitions in memory at once, reducing the memory
    overhead of writing to partitioned tables.
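
    The shape of the change is sketched below (a minimal, self-contained
    illustration; `split_by_year` is a hypothetical stand-in for the real
    partition computation, not code from this PR). Yielding one partition
    slice at a time lets the caller write and release each slice before the
    next one is computed:

    ```python
    import pyarrow as pa
    import pyarrow.compute as pc

    def split_by_year(table: pa.Table):
        # Hypothetical stand-in for the partition computation: yield one
        # slice per distinct "year" value instead of building a list of
        # all slices up front.
        for year in table.column("year").unique().to_pylist():
            mask = pc.equal(table.column("year"), year)
            # combine_chunks() still gives each slice its own buffers, but
            # only one slice is alive at a time.
            yield table.filter(mask).combine_chunks()

    table = pa.table({"year": [2020, 2020, 2021], "n_legs": [2, 4, 100]})
    for partition in split_by_year(table):
        print(partition.num_rows)  # write the slice, then let it go
    ```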
    
    ## Are these changes tested?
    
    No new tests. The tests using this method were updated to consume the
    generator as a list.
    
    However, in my personal use case, I use `pa.total_allocated_bytes()`
    to measure memory allocation before and after the write (see the
    sketch below the table), and I see the following across 5 writes of
    ~128 MB each:
    
    | Run | Original Impl (Before Write) | Original Impl (After Write) | Iters (Before Write) | Iters (After Write) |
    |---|---|---|---|---|
    | 1 | 29.31 MB | 151.62 MB | 28.38 MB | 30.40 MB |
    | 2 | 27.74 MB | 151.62 MB | 28.85 MB | 30.36 MB |
    | 3 | 28.81 MB | 151.62 MB | 28.52 MB | 31.29 MB |
    | 4 | 28.71 MB | 151.62 MB | 29.27 MB | 30.64 MB |
    | 5 | 28.60 MB | 151.61 MB | 28.29 MB | 31.11 MB |
    
    This overhead scales with the size of the write: with the original
    implementation, writing a 3 GB arrow table to a partitioned table
    requires at least 6 GB of RAM.
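
    For reference, the numbers above were collected with a pattern like the
    following (a minimal sketch; the placeholder workload stands in for the
    actual partitioned write in my pipeline):

    ```python
    import pyarrow as pa

    def measure_allocation(label: str, fn) -> None:
        # Snapshot PyArrow's default memory pool before and after the call.
        before = pa.total_allocated_bytes()
        fn()
        after = pa.total_allocated_bytes()
        print(f"{label}: before={before / 2**20:.2f} MB, after={after / 2**20:.2f} MB")

    arrow_table = pa.table({"year": [2020, 2021] * 500_000, "value": range(1_000_000)})
    # Placeholder workload; the real benchmark wraps the partitioned write
    # itself (an Iceberg table append of arrow_table).
    measure_allocation("write", lambda: arrow_table.combine_chunks())
    ```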
    
    ## Are there any user-facing changes?
    
    No.
---
 pyiceberg/io/pyarrow.py  | 41 +++++++++++++++++------------------------
 tests/io/test_pyarrow.py |  8 ++++----
 2 files changed, 21 insertions(+), 28 deletions(-)

diff --git a/pyiceberg/io/pyarrow.py b/pyiceberg/io/pyarrow.py
index e42c1307..7710df76 100644
--- a/pyiceberg/io/pyarrow.py
+++ b/pyiceberg/io/pyarrow.py
@@ -2790,11 +2790,9 @@ def _dataframe_to_data_files(
         yield from write_file(
             io=io,
             table_metadata=table_metadata,
-            tasks=iter(
-                [
-                    WriteTask(write_uuid=write_uuid, task_id=next(counter), record_batches=batches, schema=task_schema)
-                    for batches in bin_pack_arrow_table(df, target_file_size)
-                ]
+            tasks=(
+                WriteTask(write_uuid=write_uuid, task_id=next(counter), record_batches=batches, schema=task_schema)
+                for batches in bin_pack_arrow_table(df, target_file_size)
             ),
         )
     else:
@@ -2802,18 +2800,16 @@ def _dataframe_to_data_files(
         yield from write_file(
             io=io,
             table_metadata=table_metadata,
-            tasks=iter(
-                [
-                    WriteTask(
-                        write_uuid=write_uuid,
-                        task_id=next(counter),
-                        record_batches=batches,
-                        partition_key=partition.partition_key,
-                        schema=task_schema,
-                    )
-                    for partition in partitions
-                    for batches in bin_pack_arrow_table(partition.arrow_table_partition, target_file_size)
-                ]
+            tasks=(
+                WriteTask(
+                    write_uuid=write_uuid,
+                    task_id=next(counter),
+                    record_batches=batches,
+                    partition_key=partition.partition_key,
+                    schema=task_schema,
+                )
+                for partition in partitions
+                for batches in bin_pack_arrow_table(partition.arrow_table_partition, target_file_size)
             ),
         )
 
@@ -2824,7 +2820,7 @@ class _TablePartition:
     arrow_table_partition: pa.Table
 
 
-def _determine_partitions(spec: PartitionSpec, schema: Schema, arrow_table: pa.Table) -> List[_TablePartition]:
+def _determine_partitions(spec: PartitionSpec, schema: Schema, arrow_table: pa.Table) -> Iterable[_TablePartition]:
     """Based on the iceberg table partition spec, filter the arrow table into 
partitions with their keys.
 
     Example:
@@ -2852,8 +2848,6 @@ def _determine_partitions(spec: PartitionSpec, schema: Schema, arrow_table: pa.T
 
     unique_partition_fields = arrow_table.select(partition_fields).group_by(partition_fields).aggregate([])
 
-    table_partitions = []
-    # TODO: As a next step, we could also play around with yielding instead of materializing the full list
     for unique_partition in unique_partition_fields.to_pylist():
         partition_key = PartitionKey(
             field_values=[
@@ -2880,12 +2874,11 @@ def _determine_partitions(spec: PartitionSpec, schema: Schema, arrow_table: pa.T
 
         # The combine_chunks seems to be counter-intuitive to do, but it actually returns
         # fresh buffers that don't interfere with each other when it is written out to file
-        table_partitions.append(
-            _TablePartition(partition_key=partition_key, arrow_table_partition=filtered_table.combine_chunks())
+        yield _TablePartition(
+            partition_key=partition_key,
+            arrow_table_partition=filtered_table.combine_chunks(),
         )
 
-    return table_partitions
-
 
 def _get_field_from_arrow_table(arrow_table: pa.Table, field_path: str) -> pa.Array:
     """Get a field from an Arrow table, supporting both literal field names and nested field paths.
diff --git a/tests/io/test_pyarrow.py b/tests/io/test_pyarrow.py
index a19ddd60..45b9d9c9 100644
--- a/tests/io/test_pyarrow.py
+++ b/tests/io/test_pyarrow.py
@@ -2479,7 +2479,7 @@ def test_partition_for_demo() -> None:
         PartitionField(source_id=2, field_id=1002, transform=IdentityTransform(), name="n_legs_identity"),
         PartitionField(source_id=1, field_id=1001, transform=IdentityTransform(), name="year_identity"),
     )
-    result = _determine_partitions(partition_spec, test_schema, arrow_table)
+    result = list(_determine_partitions(partition_spec, test_schema, arrow_table))
     assert {table_partition.partition_key.partition for table_partition in result} == {
         Record(2, 2020),
         Record(100, 2021),
@@ -2518,7 +2518,7 @@ def test_partition_for_nested_field() -> None:
     ]
 
     arrow_table = pa.Table.from_pylist(test_data, schema=schema.as_arrow())
-    partitions = _determine_partitions(spec, schema, arrow_table)
+    partitions = list(_determine_partitions(spec, schema, arrow_table))
     partition_values = {p.partition_key.partition[0] for p in partitions}
 
     assert partition_values == {486729, 486730}
@@ -2550,7 +2550,7 @@ def test_partition_for_deep_nested_field() -> None:
     ]
 
     arrow_table = pa.Table.from_pylist(test_data, schema=schema.as_arrow())
-    partitions = _determine_partitions(spec, schema, arrow_table)
+    partitions = list(_determine_partitions(spec, schema, arrow_table))
 
     assert len(partitions) == 2  # 2 unique partitions
     partition_values = {p.partition_key.partition[0] for p in partitions}
@@ -2621,7 +2621,7 @@ def test_identity_partition_on_multi_columns() -> None:
         }
         arrow_table = pa.Table.from_pydict(test_data, schema=test_pa_schema)
 
-        result = _determine_partitions(partition_spec, test_schema, arrow_table)
+        result = list(_determine_partitions(partition_spec, test_schema, arrow_table))
 
         assert {table_partition.partition_key.partition for table_partition in result} == expected
         concatenated_arrow_table = pa.concat_tables([table_partition.arrow_table_partition for table_partition in result])
