This is an automated email from the ASF dual-hosted git repository.
zeroshade pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/iceberg-go.git
The following commit(s) were added to refs/heads/main by this push:
new 4f010424 feat(table): add fanout partition writer and rolling data
writer (#524)
4f010424 is described below
commit 4f010424b4e25954d819dd28fdeb582139b3af1b
Author: Badal Prasad Singh <[email protected]>
AuthorDate: Wed Sep 24 01:30:03 2025 +0530
feat(table): add fanout partition writer and rolling data writer (#524)
# Partitioned Fanout Writer with Rolling Data File Support (Append Mode)
This PR completes the implementation of partitioned writing with support
for rolling data files in append mode. It enables efficient,
parallelized ingestion into partitioned tables while maintaining
manifest and snapshot correctness.
**Slack Thread Discussion**:
[Link](https://apache-iceberg.slack.com/archives/C05J3MJ42BD/p1751002533414969)
**Proposal Document**: [Google
Drive](https://drive.google.com/file/d/18CwR9nhwkThs-Q-JZZvisBEaDICvp5Z7/view?usp=drive_link)
### Details
* Introduced parallel processing of `arrow.Record` using a user-defined
number of goroutines.
* Each goroutine maintains its own hash table to map partition keys to
row indices.
* After partitioning, `compute.Take()` is used to extract per-partition
data slices.
* Integrated dedicated rolling writers per partition to manage data file
size thresholds and output constraints.
### Tests Performed
* [x] Compatible with all partition transforms
* [x] Handled null values in partition columns
* [x] Validated compatibility with partition spec evolution
* [x] Verified correctness for non-linear transformation cases
* [x] Confirmed schema evolution compatibility
* [x] Partition pruning verified
---
@zeroshade — would appreciate your review when you get a chance!
---------
Signed-off-by: badalprasadsingh <[email protected]>
Co-authored-by: Matt Topol <[email protected]>
---
manifest.go | 294 ++++++++++++++++++----
manifest_test.go | 4 +
schema_conversions.go | 32 +--
schema_conversions_test.go | 149 ++++++++++++
table/arrow_utils.go | 31 ++-
table/arrow_utils_internal_test.go | 2 +-
table/internal/interfaces.go | 2 +-
table/internal/parquet_files.go | 4 +-
table/internal/parquet_files_test.go | 2 +-
table/internal/utils.go | 44 +++-
table/partitioned_fanout_writer.go | 333 +++++++++++++++++++++++++
table/partitioned_fanout_writer_test.go | 417 ++++++++++++++++++++++++++++++++
table/rolling_data_writer.go | 237 ++++++++++++++++++
table/snapshots_internal_test.go | 6 +-
table/writer.go | 14 +-
transforms_test.go | 1 +
16 files changed, 1494 insertions(+), 78 deletions(-)
diff --git a/manifest.go b/manifest.go
index 0f2483b0..2f0e829a 100644
--- a/manifest.go
+++ b/manifest.go
@@ -23,12 +23,14 @@ import (
"fmt"
"io"
"math"
+ "math/big"
"reflect"
"slices"
"strconv"
"sync"
"time"
+ "github.com/apache/arrow-go/v18/arrow/decimal128"
"github.com/apache/iceberg-go/internal"
iceio "github.com/apache/iceberg-go/io"
"github.com/google/uuid"
@@ -404,7 +406,7 @@ func (m *manifestFile) FetchEntries(fs iceio.IO,
discardDeleted bool) ([]Manifes
return fetchManifestEntries(m, fs, discardDeleted)
}
-func getFieldIDMap(sc avro.Schema) (map[string]int, map[int]avro.LogicalType) {
+func getFieldIDMap(sc avro.Schema) (map[string]int, map[int]avro.LogicalType,
map[int]int) {
getField := func(rs *avro.RecordSchema, name string) *avro.Field {
for _, f := range rs.Fields() {
if f.Name() == name {
@@ -417,30 +419,45 @@ func getFieldIDMap(sc avro.Schema) (map[string]int,
map[int]avro.LogicalType) {
result := make(map[string]int)
logicalTypes := make(map[int]avro.LogicalType)
+ fixedSizes := make(map[int]int)
+
entryField := getField(sc.(*avro.RecordSchema), "data_file")
partitionField := getField(entryField.Type().(*avro.RecordSchema),
"partition")
for _, field := range
partitionField.Type().(*avro.RecordSchema).Fields() {
- if fid, ok := field.Prop("field-id").(float64); ok {
- result[field.Name()] = int(fid)
- avroTyp := field.Type()
- if us, ok := avroTyp.(*avro.UnionSchema); ok {
- for _, t := range us.Types() {
- avroTyp = t
- }
- }
- if ps, ok := avroTyp.(*avro.PrimitiveSchema); ok &&
ps.Logical() != nil {
- logicalTypes[int(fid)] = ps.Logical().Type()
+ var fid int
+ switch v := field.Prop("field-id").(type) {
+ case int:
+ fid = v
+ case float64:
+ fid = int(v)
+ default:
+ continue
+ }
+
+ result[field.Name()] = fid
+ avroTyp := field.Type()
+ if us, ok := avroTyp.(*avro.UnionSchema); ok {
+ typeList := us.Types()
+ avroTyp = typeList[len(typeList)-1]
+ }
+ if ps, ok := avroTyp.(*avro.PrimitiveSchema); ok &&
ps.Logical() != nil {
+ logicalTypes[fid] = ps.Logical().Type()
+ } else if fs, ok := avroTyp.(*avro.FixedSchema); ok &&
fs.Logical() != nil {
+ logicalTypes[int(fid)] = fs.Logical().Type()
+ if decimalLogical, ok :=
fs.Logical().(*avro.DecimalLogicalSchema); ok {
+ fixedSizes[int(fid)] = decimalLogical.Scale()
}
}
}
- return result, logicalTypes
+ return result, logicalTypes, fixedSizes
}
type hasFieldToIDMap interface {
setFieldNameToIDMap(map[string]int)
setFieldIDToLogicalTypeMap(map[int]avro.LogicalType)
+ setFieldIDToFixedSizeMap(map[int]int)
}
func fetchManifestEntries(m ManifestFile, fs iceio.IO, discardDeleted bool)
([]ManifestEntry, error) {
@@ -570,6 +587,7 @@ type ManifestReader struct {
content ManifestContent
fieldNameToID map[string]int
fieldIDToType map[int]avro.LogicalType
+ fieldIDToSize map[int]int
// The rest are lazily populated, on demand. Most readers
// will likely only try to load the entries.
@@ -634,7 +652,7 @@ func NewManifestReader(file ManifestFile, in io.Reader)
(*ManifestReader, error)
}
}
}
- fieldNameToID, fieldIDToType := getFieldIDMap(sc)
+ fieldNameToID, fieldIDToType, fieldIDToSize := getFieldIDMap(sc)
return &ManifestReader{
dec: dec,
@@ -644,6 +662,7 @@ func NewManifestReader(file ManifestFile, in io.Reader)
(*ManifestReader, error)
content: content,
fieldNameToID: fieldNameToID,
fieldIDToType: fieldIDToType,
+ fieldIDToSize: fieldIDToSize,
}, nil
}
@@ -743,6 +762,7 @@ func (c *ManifestReader) ReadEntry() (ManifestEntry, error)
{
if fieldToIDMap, ok := tmp.DataFile().(hasFieldToIDMap); ok {
fieldToIDMap.setFieldNameToIDMap(c.fieldNameToID)
fieldToIDMap.setFieldIDToLogicalTypeMap(c.fieldIDToType)
+ fieldToIDMap.setFieldIDToFixedSizeMap(c.fieldIDToSize)
}
return tmp, nil
@@ -874,6 +894,8 @@ type partitionFieldStats[T LiteralType] struct {
func newPartitionFieldStat(typ PrimitiveType) (fieldStats, error) {
switch typ.(type) {
+ case BooleanType:
+ return &partitionFieldStats[bool]{cmp: getComparator[bool]()},
nil
case Int32Type:
return &partitionFieldStats[int32]{cmp:
getComparator[int32]()}, nil
case Int64Type:
@@ -890,6 +912,8 @@ func newPartitionFieldStat(typ PrimitiveType) (fieldStats,
error) {
return &partitionFieldStats[Time]{cmp: getComparator[Time]()},
nil
case TimestampType:
return &partitionFieldStats[Timestamp]{cmp:
getComparator[Timestamp]()}, nil
+ case TimestampTzType:
+ return &partitionFieldStats[Timestamp]{cmp:
getComparator[Timestamp]()}, nil
case UUIDType:
return &partitionFieldStats[uuid.UUID]{cmp:
getComparator[uuid.UUID]()}, nil
case BinaryType:
@@ -1017,6 +1041,9 @@ type ManifestWriter struct {
spec PartitionSpec
schema *Schema
+ partFieldNameToID map[string]int
+ partFieldIDToType map[int]avro.LogicalType
+
snapshotID int64
addedFiles int32
addedRows int64
@@ -1054,15 +1081,19 @@ func NewManifestWriter(version int, out io.Writer, spec
PartitionSpec, schema *S
return nil, err
}
+ nameToID, idToType, _ := getFieldIDMap(fileSchema)
+
w := &ManifestWriter{
- impl: impl,
- version: version,
- output: out,
- spec: spec,
- schema: schema,
- snapshotID: snapshotID,
- minSeqNum: -1,
- partitions: make([]map[int]any, 0),
+ impl: impl,
+ version: version,
+ output: out,
+ spec: spec,
+ schema: schema,
+ partFieldNameToID: nameToID,
+ partFieldIDToType: idToType,
+ snapshotID: snapshotID,
+ minSeqNum: -1,
+ partitions: make([]map[int]any, 0),
}
md, err := w.meta()
@@ -1175,7 +1206,28 @@ func (w *ManifestWriter) addEntry(entry *manifestEntry)
error {
return fmt.Errorf("unknown entry status: %v", entry.Status())
}
- w.partitions = append(w.partitions, entry.DataFile().Partition())
+ if setter, ok := entry.DataFile().(hasFieldToIDMap); ok {
+ setter.setFieldNameToIDMap(w.partFieldNameToID)
+ setter.setFieldIDToLogicalTypeMap(w.partFieldIDToType)
+ }
+
+ w.partitions = append(w.partitions, entry.Data.Partition())
+ partitionData := avroPartitionData(entry.Data.Partition(),
w.partFieldIDToType)
+
+ if dataFile, ok := entry.DataFile().(*dataFile); ok {
+ convertedPartitionData := make(map[string]any)
+ for fieldID, convertedValue := range partitionData {
+ for fieldName, id := range w.partFieldNameToID {
+ if id == fieldID {
+ convertedPartitionData[fieldName] =
convertedValue
+
+ break
+ }
+ }
+ }
+ dataFile.PartitionData = convertedPartitionData
+ }
+
if (entry.Status() == EntryStatusADDED || entry.Status() ==
EntryStatusEXISTING) &&
entry.SequenceNum() > 0 && (w.minSeqNum < 0 ||
entry.SequenceNum() < w.minSeqNum) {
w.minSeqNum = entry.SequenceNum()
@@ -1505,29 +1557,115 @@ func avroPartitionData(input map[int]any, logicalTypes
map[int]avro.LogicalType)
out := make(map[int]any)
for k, v := range input {
if logical, ok := logicalTypes[k]; ok {
- switch logical {
- case avro.Date:
- out[k] =
Date(v.(time.Time).Truncate(24*time.Hour).Unix() / int64((time.Hour *
24).Seconds()))
- case avro.TimeMillis:
- out[k] = Time(v.(time.Duration).Milliseconds())
- case avro.TimeMicros:
- out[k] = Time(v.(time.Duration).Microseconds())
- case avro.TimestampMillis:
- out[k] =
Timestamp(v.(time.Time).UTC().UnixMilli())
- case avro.TimestampMicros:
- out[k] =
Timestamp(v.(time.Time).UTC().UnixMicro())
- default:
- out[k] = v
- }
-
- continue
+ out[k] = convertLogicalTypeValue(v, logical)
+ } else {
+ out[k] = v
}
- out[k] = v
}
return out
}
+func convertLogicalTypeValue(v any, logicalType avro.LogicalType) any {
+ switch logicalType {
+ case avro.Date:
+ return convertDateValue(v)
+ case avro.TimeMicros:
+ return convertTimeMicrosValue(v)
+ case avro.TimestampMicros:
+ return convertTimestampMicrosValue(v)
+ case avro.Decimal:
+ return convertDecimalValue(v)
+ case avro.UUID:
+ return convertUUIDValue(v)
+ default:
+ return v
+ }
+}
+
+func convertDateValue(v any) any {
+ if v == nil {
+ return map[string]any{"null": nil}
+ }
+
+ if d, ok := v.(Date); ok {
+ return map[string]any{"int.date": int32(d)}
+ }
+
+ return v
+}
+
+func convertTimeMicrosValue(v any) any {
+ if v == nil {
+ return map[string]any{"null": nil}
+ }
+
+ if t, ok := v.(Time); ok {
+ return map[string]any{"long.time-micros": int64(t)}
+ }
+
+ return v
+}
+
+func convertTimestampMicrosValue(v any) any {
+ if v == nil {
+ return map[string]any{"null": nil}
+ }
+
+ if ts, ok := v.(Timestamp); ok {
+ return map[string]any{"long.timestamp-micros": int64(ts)}
+ }
+
+ return v
+}
+
+func convertDecimalValue(v any) any {
+ if v == nil {
+ return map[string]any{"null": nil}
+ }
+
+ if dec, ok := v.(Decimal); ok {
+ fixedSize := internal.DecimalRequiredBytes(len(dec.String()))
+ bytes, err := DecimalLiteral(dec).MarshalBinary()
+ if err != nil {
+ return v
+ }
+ fixedArray := convertToFixedArray(padOrTruncateBytes(bytes,
fixedSize), fixedSize)
+
+ return map[string]any{"fixed": fixedArray}
+ }
+
+ return v
+}
+
+func convertUUIDValue(v any) any {
+ if v == nil {
+ return map[string]any{"null": nil}
+ }
+
+ if uuidVal, ok := v.(uuid.UUID); ok {
+ return map[string]any{"uuid": [16]byte(uuidVal)}
+ }
+
+ return v
+}
+
+func padOrTruncateBytes(bytes []byte, size int) []byte {
+ if len(bytes) >= size {
+ return bytes[len(bytes)-size:]
+ }
+ padded := slices.Grow(bytes, size-len(bytes))
+
+ return append(make([]byte, size-len(bytes)), padded...)
+}
+
+func convertToFixedArray(bytes []byte, size int) any {
+ arr := reflect.New(reflect.ArrayOf(size,
reflect.TypeOf(byte(0)))).Elem()
+ reflect.Copy(arr, reflect.ValueOf(bytes))
+
+ return arr.Interface()
+}
+
type dataFile struct {
Content ManifestEntryContent `avro:"content"`
Path string `avro:"file_path"`
@@ -1564,6 +1702,7 @@ type dataFile struct {
fieldNameToID map[string]int
fieldIDToLogicalType map[int]avro.LogicalType
fieldIDToPartitionData map[int]any
+ fieldIDToFixedSize map[int]int
specID int32
initMaps sync.Once
@@ -1583,18 +1722,87 @@ func (d *dataFile) initializeMapData() {
d.fieldIDToPartitionData = make(map[int]any,
len(d.PartitionData))
for k, v := range d.PartitionData {
if id, ok := d.fieldNameToID[k]; ok {
- d.fieldIDToPartitionData[id] = v
+ convertedValue :=
d.convertAvroValueToIcebergType(v, id)
+ d.fieldIDToPartitionData[id] =
convertedValue
}
}
}
- d.fieldIDToPartitionData =
avroPartitionData(d.fieldIDToPartitionData, d.fieldIDToLogicalType)
})
}
+func (d *dataFile) convertAvroValueToIcebergType(v any, fieldID int) any {
+ if logicalType, ok := d.fieldIDToLogicalType[fieldID]; ok {
+ switch logicalType {
+ case avro.Date:
+ if val, ok := v.(time.Time); ok {
+ return Date(val.Truncate(24*time.Hour).Unix() /
int64((time.Hour * 24).Seconds()))
+ }
+
+ return Date(v.(int32))
+ case avro.TimeMillis:
+ if val, ok := v.(time.Duration); ok {
+ return Time(val.Milliseconds())
+ }
+
+ return Time(v.(int64))
+ case avro.TimeMicros:
+ if val, ok := v.(time.Duration); ok {
+ return Time(val.Microseconds())
+ }
+
+ return Time(v.(int64))
+ case avro.TimestampMillis:
+ if val, ok := v.(time.Time); ok {
+ return Timestamp(val.UTC().UnixMilli())
+ }
+
+ return Timestamp(v.(int64))
+ case avro.TimestampMicros:
+ if val, ok := v.(time.Time); ok {
+ return Timestamp(val.UTC().UnixMicro())
+ }
+
+ return Timestamp(v.(int64))
+ case avro.Decimal:
+ if unionMap, ok := v.(map[string]interface{}); ok {
+ if val, ok := unionMap["fixed"]; ok {
+ if bigRatValue, ok := val.(*big.Rat);
ok {
+ scale :=
d.fieldIDToFixedSize[fieldID]
+ scaleFactor :=
new(big.Int).Exp(big.NewInt(10), big.NewInt(int64(scale)), nil)
+ unscaled :=
new(big.Int).Mul(bigRatValue.Num(), scaleFactor)
+ unscaled =
unscaled.Div(unscaled, bigRatValue.Denom())
+ decimal128Val :=
decimal128.FromBigInt(unscaled)
+
+ return DecimalLiteral{
+ Scale: scale,
+ Val: decimal128Val,
+ }
+ }
+ }
+ }
+
+ return v
+ case avro.UUID:
+ if unionMap, ok := v.(map[string]interface{}); ok {
+ if val, ok := unionMap["uuid"]; ok {
+ if uuidArr, ok := val.([16]byte); ok {
+ return uuid.UUID(uuidArr)
+ }
+ }
+ }
+
+ return v
+ }
+ }
+
+ return v
+}
+
func (d *dataFile) setFieldNameToIDMap(m map[string]int) { d.fieldNameToID = m
}
func (d *dataFile) setFieldIDToLogicalTypeMap(m map[int]avro.LogicalType) {
d.fieldIDToLogicalType = m
}
+func (d *dataFile) setFieldIDToFixedSizeMap(m map[int]int) {
d.fieldIDToFixedSize = m }
func (d *dataFile) ContentType() ManifestEntryContent { return d.Content }
func (d *dataFile) FilePath() string { return d.Path }
@@ -1815,6 +2023,8 @@ func NewDataFileBuilder(
path string,
format FileFormat,
fieldIDToPartitionData map[int]any,
+ fieldIDToLogicalType map[int]avro.LogicalType,
+ fieldIDToFixedSize map[int]int,
recordCount int64,
fileSize int64,
) (*DataFileBuilder, error) {
@@ -1863,6 +2073,8 @@ func NewDataFileBuilder(
specID: int32(spec.id),
fieldIDToPartitionData: fieldIDToPartitionData,
fieldNameToID: fieldNameToID,
+ fieldIDToLogicalType: fieldIDToLogicalType,
+ fieldIDToFixedSize: fieldIDToFixedSize,
},
}, nil
}
diff --git a/manifest_test.go b/manifest_test.go
index 3e58906b..7927cc57 100644
--- a/manifest_test.go
+++ b/manifest_test.go
@@ -815,6 +815,8 @@ func (m *ManifestTestSuite)
TestReadManifestIncompleteSchema() {
"s3://bucket/namespace/table/data/abcd-0123.parquet",
ParquetFile,
map[int]any{},
+ map[int]avro.LogicalType{},
+ map[int]int{},
100,
100*1000*1000,
)
@@ -1079,6 +1081,8 @@ func (m *ManifestTestSuite) TestManifestEntryBuilder() {
"sample.parquet",
ParquetFile,
map[int]any{1001: int(1), 1002: time.Unix(1925, 0).UnixMicro()},
+ map[int]avro.LogicalType{},
+ map[int]int{},
1,
2,
)
diff --git a/schema_conversions.go b/schema_conversions.go
index 05b6879d..f51ca7bd 100644
--- a/schema_conversions.go
+++ b/schema_conversions.go
@@ -30,33 +30,37 @@ func partitionTypeToAvroSchema(t *StructType) (avro.Schema,
error) {
var sc avro.Schema
switch typ := f.Type.(type) {
case Int32Type:
- sc = internal.IntSchema
+ sc = internal.NullableSchema(internal.IntSchema)
case Int64Type:
- sc = internal.LongSchema
+ sc = internal.NullableSchema(internal.LongSchema)
case Float32Type:
- sc = internal.FloatSchema
+ sc = internal.NullableSchema(internal.FloatSchema)
case Float64Type:
- sc = internal.DoubleSchema
+ sc = internal.NullableSchema(internal.DoubleSchema)
case StringType:
- sc = internal.StringSchema
+ sc = internal.NullableSchema(internal.StringSchema)
case DateType:
- sc = internal.DateSchema
+ sc = internal.NullableSchema(internal.DateSchema)
case TimeType:
- sc = internal.TimeSchema
+ sc = internal.NullableSchema(internal.TimeSchema)
case TimestampType:
- sc = internal.TimestampSchema
+ sc = internal.NullableSchema(internal.TimestampSchema)
case TimestampTzType:
- sc = internal.TimestampTzSchema
+ sc = internal.NullableSchema(internal.TimestampTzSchema)
case UUIDType:
- sc = internal.UUIDSchema
+ sc = internal.NullableSchema(internal.UUIDSchema)
case BooleanType:
- sc = internal.BoolSchema
+ sc = internal.NullableSchema(internal.BoolSchema)
case BinaryType:
- sc = internal.BinarySchema
+ sc = internal.NullableSchema(internal.BinarySchema)
case FixedType:
- sc = internal.Must(avro.NewFixedSchema("fixed", "",
typ.len, nil))
+ // Currently the hamba/avro library couldn't resolve
the [n]byte array types for fixed schemas in unions.
+ // https://github.com/hamba/avro/issues/571
+ // TODO: Create the proper Fixed Schema for Avro that
can match the use case
+ sc = internal.NullableSchema(internal.BinarySchema)
case DecimalType:
- sc = internal.DecimalSchema(typ.precision, typ.scale)
+ decimalSchema := internal.DecimalSchema(typ.precision,
typ.scale)
+ sc = internal.NullableSchema(decimalSchema)
default:
return nil, fmt.Errorf("unsupported partition type:
%s", f.Type.String())
}
diff --git a/schema_conversions_test.go b/schema_conversions_test.go
new file mode 100644
index 00000000..9b86e940
--- /dev/null
+++ b/schema_conversions_test.go
@@ -0,0 +1,149 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+package iceberg
+
+import (
+ "fmt"
+ "testing"
+
+ "github.com/apache/iceberg-go/internal"
+ "github.com/hamba/avro/v2"
+ "github.com/stretchr/testify/assert"
+ "github.com/stretchr/testify/require"
+)
+
+func partitionTypeToAvroSchemaNonNullable(t *StructType) (avro.Schema, error) {
+ fields := make([]*avro.Field, len(t.FieldList))
+ for i, f := range t.FieldList {
+ var sc avro.Schema
+ switch typ := f.Type.(type) {
+ case Int32Type:
+ sc = internal.IntSchema
+ case Int64Type:
+ sc = internal.LongSchema
+ case Float32Type:
+ sc = internal.FloatSchema
+ case Float64Type:
+ sc = internal.DoubleSchema
+ case StringType:
+ sc = internal.StringSchema
+ case DateType:
+ sc = internal.DateSchema
+ case TimeType:
+ sc = internal.TimeSchema
+ case TimestampType:
+ sc = internal.TimestampSchema
+ case TimestampTzType:
+ sc = internal.TimestampTzSchema
+ case UUIDType:
+ sc = internal.UUIDSchema
+ case BooleanType:
+ sc = internal.BoolSchema
+ case BinaryType:
+ sc = internal.BinarySchema
+ case FixedType:
+ sc = internal.BinarySchema
+ case DecimalType:
+ decimalSchema := internal.DecimalSchema(typ.precision,
typ.scale)
+ sc = decimalSchema
+ default:
+ return nil, fmt.Errorf("unsupported partition type:
%s", f.Type.String())
+ }
+
+ fields[i], _ = avro.NewField(f.Name, sc,
internal.WithFieldID(f.ID))
+ }
+
+ return avro.NewRecordSchema("r102", "", fields)
+}
+
+func TestPartitionTypeToAvroSchemaNullableAndNonNullable(t *testing.T) {
+ partitionType := &StructType{
+ FieldList: []NestedField{
+ {ID: 1, Name: "int32_col", Type: Int32Type{}, Required:
false},
+ {ID: 2, Name: "int64_col", Type: Int64Type{}, Required:
false},
+ {ID: 3, Name: "float32_col", Type: Float32Type{},
Required: false},
+ {ID: 4, Name: "float64_col", Type: Float64Type{},
Required: false},
+ {ID: 5, Name: "string_col", Type: StringType{},
Required: false},
+ {ID: 6, Name: "date_col", Type: DateType{}, Required:
false},
+ {ID: 7, Name: "time_col", Type: TimeType{}, Required:
false},
+ {ID: 8, Name: "timestamp_col", Type: TimestampType{},
Required: false},
+ {ID: 9, Name: "timestamptz_col", Type:
TimestampTzType{}, Required: false},
+ {ID: 10, Name: "uuid_col", Type: UUIDType{}, Required:
false},
+ {ID: 11, Name: "bool_col", Type: BooleanType{},
Required: false},
+ {ID: 12, Name: "binary_col", Type: BinaryType{},
Required: false},
+ {ID: 13, Name: "fixed_col", Type: FixedType{len: 16},
Required: false},
+ {ID: 14, Name: "decimal_col", Type:
DecimalType{precision: 10, scale: 2}, Required: false},
+ },
+ }
+
+ partitionData := map[string]any{
+ "int32_col": nil,
+ "int64_col": nil,
+ "float32_col": nil,
+ "float64_col": nil,
+ "string_col": nil,
+ "date_col": nil,
+ "time_col": nil,
+ "timestamp_col": nil,
+ "timestamptz_col": nil,
+ "uuid_col": nil,
+ "bool_col": nil,
+ "binary_col": nil,
+ "fixed_col": nil,
+ "decimal_col": nil,
+ }
+
+ t.Run("nullable schema accepts nil", func(t *testing.T) {
+ schemaNullable, err := partitionTypeToAvroSchema(partitionType)
+ require.NoError(t, err)
+ require.NotNil(t, schemaNullable)
+
+ encoded, err := avro.Marshal(schemaNullable, partitionData)
+ require.NoError(t, err)
+ require.NotEmpty(t, encoded)
+
+ var decoded map[string]any
+ err = avro.Unmarshal(schemaNullable, encoded, &decoded)
+ require.NoError(t, err)
+
+ assert.Nil(t, decoded["int32_col"])
+ assert.Nil(t, decoded["int64_col"])
+ assert.Nil(t, decoded["float32_col"])
+ assert.Nil(t, decoded["float64_col"])
+ assert.Nil(t, decoded["string_col"])
+ assert.Nil(t, decoded["date_col"])
+ assert.Nil(t, decoded["time_col"])
+ assert.Nil(t, decoded["timestamp_col"])
+ assert.Nil(t, decoded["timestamptz_col"])
+ assert.Nil(t, decoded["uuid_col"])
+ assert.Nil(t, decoded["bool_col"])
+ assert.Nil(t, decoded["binary_col"])
+ assert.Nil(t, decoded["fixed_col"])
+ assert.Nil(t, decoded["decimal_col"])
+ })
+
+ t.Run("non-nullable schema rejects nil", func(t *testing.T) {
+ schemaNonNullable, err :=
partitionTypeToAvroSchemaNonNullable(partitionType)
+ require.NoError(t, err)
+ require.NotNil(t, schemaNonNullable)
+
+ encoded, err := avro.Marshal(schemaNonNullable, partitionData)
+ require.Error(t, err, "expected marshal to fail when values are
nil for non-nullable schema")
+ assert.Empty(t, encoded)
+ })
+}
diff --git a/table/arrow_utils.go b/table/arrow_utils.go
index f64b0bce..0f107122 100644
--- a/table/arrow_utils.go
+++ b/table/arrow_utils.go
@@ -32,6 +32,7 @@ import (
"github.com/apache/arrow-go/v18/arrow/extensions"
"github.com/apache/arrow-go/v18/arrow/memory"
"github.com/apache/iceberg-go"
+ "github.com/apache/iceberg-go/config"
"github.com/apache/iceberg-go/internal"
iceio "github.com/apache/iceberg-go/io"
tblutils "github.com/apache/iceberg-go/table/internal"
@@ -1222,7 +1223,23 @@ func filesToDataFiles(ctx context.Context, fileIO
iceio.IO, meta *MetadataBuilde
statistics :=
format.DataFileStatsFromMeta(rdr.Metadata(),
must(computeStatsPlan(currentSchema, meta.props)),
must(format.PathToIDMapping(currentSchema)))
- df := statistics.ToDataFile(currentSchema, currentSpec,
filePath, iceberg.ParquetFile, rdr.SourceFileSize())
+ partitionValues := make(map[int]any)
+ if !currentSpec.Equals(*iceberg.UnpartitionedSpec) {
+ for field := range currentSpec.Fields() {
+ if !field.Transform.PreservesOrder() {
+ yield(nil, fmt.Errorf("cannot
infer partition value from parquet metadata for a non-linear partition field:
%s with transform %s", field.Name, field.Transform))
+
+ return
+ }
+
+ partitionVal :=
statistics.PartitionValue(field, currentSchema)
+ if partitionVal != nil {
+ partitionValues[field.FieldID]
= partitionVal
+ }
+ }
+ }
+
+ df := statistics.ToDataFile(currentSchema, currentSpec,
filePath, iceberg.ParquetFile, rdr.SourceFileSize(), partitionValues)
if !yield(df, nil) {
return
}
@@ -1298,6 +1315,7 @@ func recordsToDataFiles(ctx context.Context, rootLocation
string, meta *Metadata
if err != nil || currentSpec == nil {
panic(fmt.Errorf("%w: cannot write files without a current
spec", err))
}
+
nextCount, stopCount := iter.Pull(args.counter)
if currentSpec.IsUnpartitioned() {
tasks := func(yield func(WriteTask) bool) {
@@ -1317,8 +1335,13 @@ func recordsToDataFiles(ctx context.Context,
rootLocation string, meta *Metadata
}
}
- return writeFiles(ctx, rootLocation, args.fs, meta, tasks)
- }
+ return writeFiles(ctx, rootLocation, args.fs, meta, nil, tasks)
+ } else {
+ partitionWriter := newPartitionedFanoutWriter(*currentSpec,
meta.CurrentSchema(), args.itr)
+ rollingDataWriters := NewWriterFactory(rootLocation, args,
meta, taskSchema, targetFileSize)
+ partitionWriter.writers = &rollingDataWriters
+ workers := config.EnvConfig.MaxWorkers
- panic(fmt.Errorf("%w: write stream with partitions",
iceberg.ErrNotImplemented))
+ return partitionWriter.Write(ctx, workers)
+ }
}
diff --git a/table/arrow_utils_internal_test.go
b/table/arrow_utils_internal_test.go
index 954ee5de..4305642e 100644
--- a/table/arrow_utils_internal_test.go
+++ b/table/arrow_utils_internal_test.go
@@ -200,7 +200,7 @@ func (suite *FileStatsMetricsSuite) getDataFile(meta
iceberg.Properties, writeSt
stats := format.DataFileStatsFromMeta(fileMeta, collector, mapping)
return stats.ToDataFile(tableMeta.CurrentSchema(),
tableMeta.PartitionSpec(), "fake-path.parquet",
- iceberg.ParquetFile, fileMeta.GetSourceFileSize())
+ iceberg.ParquetFile, fileMeta.GetSourceFileSize(), nil)
}
func (suite *FileStatsMetricsSuite) TestRecordCount() {
diff --git a/table/internal/interfaces.go b/table/internal/interfaces.go
index bd474775..0fdb8097 100644
--- a/table/internal/interfaces.go
+++ b/table/internal/interfaces.go
@@ -78,7 +78,7 @@ type FileFormat interface {
PathToIDMapping(*iceberg.Schema) (map[string]int, error)
DataFileStatsFromMeta(rdr Metadata, statsCols
map[int]StatisticsCollector, colMapping map[string]int) *DataFileStatistics
GetWriteProperties(iceberg.Properties) any
- WriteDataFile(ctx context.Context, fs iceio.WriteFileIO, info
WriteFileInfo, batches []arrow.RecordBatch) (iceberg.DataFile, error)
+ WriteDataFile(ctx context.Context, fs iceio.WriteFileIO,
partitionValues map[int]any, info WriteFileInfo, batches []arrow.RecordBatch)
(iceberg.DataFile, error)
}
func GetFileFormat(format iceberg.FileFormat) FileFormat {
diff --git a/table/internal/parquet_files.go b/table/internal/parquet_files.go
index 0c2cb8b5..e5f7c76e 100644
--- a/table/internal/parquet_files.go
+++ b/table/internal/parquet_files.go
@@ -236,7 +236,7 @@ func (parquetFormat) GetWriteProperties(props
iceberg.Properties) any {
parquet.WithCompressionLevel(compressionLevel))
}
-func (p parquetFormat) WriteDataFile(ctx context.Context, fs
iceio.WriteFileIO, info WriteFileInfo, batches []arrow.RecordBatch)
(iceberg.DataFile, error) {
+func (p parquetFormat) WriteDataFile(ctx context.Context, fs
iceio.WriteFileIO, partitionValues map[int]any, info WriteFileInfo, batches
[]arrow.RecordBatch) (iceberg.DataFile, error) {
fw, err := fs.Create(info.FileName)
if err != nil {
return nil, err
@@ -274,7 +274,7 @@ func (p parquetFormat) WriteDataFile(ctx context.Context,
fs iceio.WriteFileIO,
}
return p.DataFileStatsFromMeta(filemeta, info.StatsCols, colMapping).
- ToDataFile(info.FileSchema, info.Spec, info.FileName,
iceberg.ParquetFile, cntWriter.Count), nil
+ ToDataFile(info.FileSchema, info.Spec, info.FileName,
iceberg.ParquetFile, cntWriter.Count, partitionValues), nil
}
type decAsIntAgg[T int32 | int64] struct {
diff --git a/table/internal/parquet_files_test.go
b/table/internal/parquet_files_test.go
index ac16db3f..19db8a72 100644
--- a/table/internal/parquet_files_test.go
+++ b/table/internal/parquet_files_test.go
@@ -256,7 +256,7 @@ func TestMetricsPrimitiveTypes(t *testing.T) {
stats := format.DataFileStatsFromMeta(internal.Metadata(meta),
getCollector(), mapping)
df := stats.ToDataFile(tblMeta.CurrentSchema(),
tblMeta.PartitionSpec(), "fake-path.parquet",
- iceberg.ParquetFile, meta.GetSourceFileSize())
+ iceberg.ParquetFile, meta.GetSourceFileSize(), nil)
assert.Len(t, df.ValueCounts(), 15)
assert.Len(t, df.NullValueCounts(), 15)
diff --git a/table/internal/utils.go b/table/internal/utils.go
index 82363e09..6227d746 100644
--- a/table/internal/utils.go
+++ b/table/internal/utils.go
@@ -34,6 +34,7 @@ import (
"github.com/apache/arrow-go/v18/arrow/decimal"
"github.com/apache/arrow-go/v18/arrow/decimal128"
"github.com/apache/iceberg-go"
+ "github.com/hamba/avro/v2"
"golang.org/x/sync/errgroup"
)
@@ -212,8 +213,7 @@ func (d *DataFileStatistics) PartitionValue(field
iceberg.PartitionField, sc *ic
}
if !field.Transform.PreservesOrder() {
- panic(fmt.Errorf("cannot infer partition value from parquet
metadata for a non-linear partition field: %s with transform %s",
- field.Name, field.Transform))
+ return nil
}
lowerRec := must(PartitionRecordValue(field, agg.Min(), sc))
@@ -234,20 +234,50 @@ func (d *DataFileStatistics) PartitionValue(field
iceberg.PartitionField, sc *ic
return lowerT.Val.Any()
}
-func (d *DataFileStatistics) ToDataFile(schema *iceberg.Schema, spec
iceberg.PartitionSpec, path string, format iceberg.FileFormat, filesize int64)
iceberg.DataFile {
+func (d *DataFileStatistics) ToDataFile(schema *iceberg.Schema, spec
iceberg.PartitionSpec, path string, format iceberg.FileFormat, filesize int64,
partitionValues map[int]any) iceberg.DataFile {
var fieldIDToPartitionData map[int]any
+ fieldIDToLogicalType := make(map[int]avro.LogicalType)
+ fieldIDToFixedSize := make(map[int]int)
+
if !spec.Equals(*iceberg.UnpartitionedSpec) {
fieldIDToPartitionData = make(map[int]any)
for field := range spec.Fields() {
- val := d.PartitionValue(field, schema)
- if val != nil {
- fieldIDToPartitionData[field.FieldID] = val
+ partitionVal := partitionValues[field.FieldID]
+ if partitionVal != nil {
+ val := d.PartitionValue(field, schema)
+ if val != nil {
+ fieldIDToPartitionData[field.FieldID] =
val
+ } else {
+ fieldIDToPartitionData[field.FieldID] =
partitionVal
+ }
+ } else {
+ fieldIDToPartitionData[field.FieldID] = nil
+ }
+
+ if sourceField, ok :=
schema.FindFieldByID(field.SourceID); ok {
+ resultType :=
field.Transform.ResultType(sourceField.Type)
+
+ switch rt := resultType.(type) {
+ case iceberg.DateType:
+ fieldIDToLogicalType[field.FieldID] =
avro.Date
+ case iceberg.TimeType:
+ fieldIDToLogicalType[field.FieldID] =
avro.TimeMicros
+ case iceberg.TimestampType:
+ fieldIDToLogicalType[field.FieldID] =
avro.TimestampMicros
+ case iceberg.TimestampTzType:
+ fieldIDToLogicalType[field.FieldID] =
avro.TimestampMicros
+ case iceberg.DecimalType:
+ fieldIDToLogicalType[field.FieldID] =
avro.Decimal
+ fieldIDToFixedSize[field.FieldID] =
rt.Scale()
+ case iceberg.UUIDType:
+ fieldIDToLogicalType[field.FieldID] =
avro.UUID
+ }
}
}
}
bldr, err := iceberg.NewDataFileBuilder(spec, iceberg.EntryContentData,
- path, format, fieldIDToPartitionData, d.RecordCount, filesize)
+ path, format, fieldIDToPartitionData, fieldIDToLogicalType,
fieldIDToFixedSize, d.RecordCount, filesize)
if err != nil {
panic(err)
}
diff --git a/table/partitioned_fanout_writer.go
b/table/partitioned_fanout_writer.go
new file mode 100644
index 00000000..a4b41d3c
--- /dev/null
+++ b/table/partitioned_fanout_writer.go
@@ -0,0 +1,333 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+package table
+
+import (
+ "context"
+ "errors"
+ "fmt"
+ "iter"
+
+ "github.com/apache/arrow-go/v18/arrow"
+ "github.com/apache/arrow-go/v18/arrow/array"
+ "github.com/apache/arrow-go/v18/arrow/compute"
+ "github.com/apache/arrow-go/v18/arrow/decimal128"
+ "github.com/apache/arrow-go/v18/arrow/extensions"
+ "github.com/apache/iceberg-go"
+ "golang.org/x/sync/errgroup"
+)
+
// partitionedFanoutWriter distributes Arrow records across multiple partitions
// based on a partition specification, writing data to separate files for each
// partition using a fanout pattern with configurable parallelism.
type partitionedFanoutWriter struct {
	partitionSpec iceberg.PartitionSpec             // spec whose transforms derive partition keys
	schema        *iceberg.Schema                   // Iceberg schema used to resolve source columns
	itr           iter.Seq2[arrow.RecordBatch, error] // input stream of record batches
	writers       *writerFactory                    // creates/caches one rolling writer per partition
}
+
// partitionInfo holds the row indices and partition values for a specific
// partition, used during the fanout process to group rows by their partition
// key.
type partitionInfo struct {
	rows            []int64     // row indices (within one record batch) belonging to this partition
	partitionValues map[int]any // partition field ID -> transformed partition value
}
+
+// NewPartitionedFanoutWriter creates a new PartitionedFanoutWriter with the
specified
+// partition specification, schema, and record iterator.
+func newPartitionedFanoutWriter(partitionSpec iceberg.PartitionSpec, schema
*iceberg.Schema, itr iter.Seq2[arrow.RecordBatch, error])
*partitionedFanoutWriter {
+ return &partitionedFanoutWriter{
+ partitionSpec: partitionSpec,
+ schema: schema,
+ itr: itr,
+ }
+}
+
// partitionPath renders the partition record via the spec's path encoding;
// the resulting string doubles as the partition's map key during fanout.
func (p *partitionedFanoutWriter) partitionPath(data partitionRecord) string {
	return p.partitionSpec.PartitionToPath(data, p.schema)
}
+
+// Write writes the Arrow records to the specified location using a fanout
pattern with
+// the specified number of workers. The returned iterator yields the data
files written
+// by the fanout process.
+func (p *partitionedFanoutWriter) Write(ctx context.Context, workers int)
iter.Seq2[iceberg.DataFile, error] {
+ inputRecordsCh := make(chan arrow.RecordBatch, workers)
+ outputDataFilesCh := make(chan iceberg.DataFile, workers)
+
+ fanoutWorkers, ctx := errgroup.WithContext(ctx)
+ p.startRecordFeeder(ctx, fanoutWorkers, inputRecordsCh)
+
+ for range workers {
+ fanoutWorkers.Go(func() error {
+ return p.fanout(ctx, inputRecordsCh, outputDataFilesCh)
+ })
+ }
+
+ return p.yieldDataFiles(fanoutWorkers, outputDataFilesCh)
+}
+
// startRecordFeeder launches a goroutine on fanoutWorkers that pumps records
// from the writer's iterator into inputRecordsCh, closing the channel once
// the iterator is exhausted (or an error/cancellation ends the feed). Each
// record is retained before the send and released again if cancellation wins
// the select, so ownership transfers to the receiving worker exactly once.
func (p *partitionedFanoutWriter) startRecordFeeder(ctx context.Context, fanoutWorkers *errgroup.Group, inputRecordsCh chan<- arrow.RecordBatch) {
	fanoutWorkers.Go(func() error {
		defer close(inputRecordsCh)

		for record, err := range p.itr {
			if err != nil {
				return err
			}

			record.Retain()
			select {
			case <-ctx.Done():
				record.Release()

				return context.Cause(ctx)
			case inputRecordsCh <- record:
			}
		}

		return nil
	})
}
+
+func (p *partitionedFanoutWriter) fanout(ctx context.Context, inputRecordsCh
<-chan arrow.RecordBatch, dataFilesChannel chan<- iceberg.DataFile) error {
+ for {
+ select {
+ case <-ctx.Done():
+ return context.Cause(ctx)
+
+ case record, ok := <-inputRecordsCh:
+ if !ok {
+ return nil
+ }
+ defer record.Release()
+
+ partitionMap, err := p.getPartitionMap(record)
+ if err != nil {
+ return err
+ }
+
+ for partition, val := range partitionMap {
+ select {
+ case <-ctx.Done():
+ return context.Cause(ctx)
+ default:
+ }
+
+ partitionRecord, err :=
partitionBatchByKey(ctx)(record, val.rows)
+ if err != nil {
+ return err
+ }
+
+ rollingDataWriter, err :=
p.writers.getOrCreateRollingDataWriter(ctx, partition, val.partitionValues,
dataFilesChannel)
+ if err != nil {
+ return err
+ }
+
+ err = rollingDataWriter.Add(partitionRecord)
+ if err != nil {
+ return err
+ }
+ }
+ }
+ }
+}
+
+func (p *partitionedFanoutWriter) yieldDataFiles(fanoutWorkers
*errgroup.Group, outputDataFilesCh chan iceberg.DataFile)
iter.Seq2[iceberg.DataFile, error] {
+ var err error
+ go func() {
+ defer close(outputDataFilesCh)
+ err = fanoutWorkers.Wait()
+ err = errors.Join(err, p.writers.closeAll())
+ }()
+
+ return func(yield func(iceberg.DataFile, error) bool) {
+ defer func() {
+ for range outputDataFilesCh {
+ }
+ }()
+
+ for f := range outputDataFilesCh {
+ if !yield(f, err) {
+ return
+ }
+ }
+
+ if err != nil {
+ yield(nil, err)
+ }
+ }
+}
+
// getPartitionMap groups the record's rows by partition key. For every row it
// applies each partition field's transform to the source column value, builds
// the partition record, renders it to the partition-path string used as the
// map key, and accumulates the row index plus the transformed partition
// values under that key. Null source cells yield nil partition values.
func (p *partitionedFanoutWriter) getPartitionMap(record arrow.RecordBatch) (map[string]partitionInfo, error) {
	partitionMap := make(map[string]partitionInfo)
	partitionFields := p.partitionSpec.PartitionType(p.schema).FieldList
	// partitionRec is reused across rows; it is fully overwritten per row
	// before being rendered into the key.
	partitionRec := make(partitionRecord, len(partitionFields))

	// Resolve each partition field's source column and field ID once, up
	// front, instead of per row.
	partitionColumns := make([]arrow.Array, len(partitionFields))
	partitionFieldsInfo := make([]struct {
		sourceField *iceberg.PartitionField
		fieldID     int
	}, len(partitionFields))

	for i := range partitionFields {
		sourceField := p.partitionSpec.Field(i)
		colName, _ := p.schema.FindColumnName(sourceField.SourceID)
		// NOTE(review): FieldIndices is indexed without a length check —
		// assumes every partition source column exists in the record's
		// schema; confirm callers guarantee this.
		colIdx := record.Schema().FieldIndices(colName)[0]
		partitionColumns[i] = record.Column(colIdx)
		partitionFieldsInfo[i] = struct {
			sourceField *iceberg.PartitionField
			fieldID     int
		}{&sourceField, sourceField.FieldID}
	}

	for row := range record.NumRows() {
		partitionValues := make(map[int]any)
		for i := range partitionFields {
			col := partitionColumns[i]
			if !col.IsNull(int(row)) {
				sourceField := partitionFieldsInfo[i].sourceField
				val, err := getArrowValueAsIcebergLiteral(col, int(row))
				if err != nil {
					return nil, fmt.Errorf("failed to get arrow values as iceberg literal: %w", err)
				}

				transformedLiteral := sourceField.Transform.Apply(iceberg.Optional[iceberg.Literal]{Valid: true, Val: val})
				if transformedLiteral.Valid {
					partitionRec[i] = transformedLiteral.Val.Any()
					partitionValues[sourceField.FieldID] = transformedLiteral.Val.Any()
				} else {
					// Transform produced no value (e.g. void): treat as null.
					partitionRec[i], partitionValues[sourceField.FieldID] = nil, nil
				}
			} else {
				partitionRec[i], partitionValues[partitionFieldsInfo[i].fieldID] = nil, nil
			}
		}
		partitionKey := p.partitionPath(partitionRec)
		partVal := partitionMap[partitionKey]
		partVal.rows = append(partitionMap[partitionKey].rows, row)
		// Rows in the same partition share identical partition values, so
		// overwriting on every row is safe.
		partVal.partitionValues = partitionValues
		partitionMap[partitionKey] = partVal
	}

	return partitionMap, nil
}
+
+type partitionBatchFn func(arrow.RecordBatch, []int64) (arrow.RecordBatch,
error)
+
+func partitionBatchByKey(ctx context.Context) partitionBatchFn {
+ mem := compute.GetAllocator(ctx)
+
+ return func(record arrow.RecordBatch, rowIndices []int64)
(arrow.RecordBatch, error) {
+ bldr := array.NewInt64Builder(mem)
+ defer bldr.Release()
+
+ bldr.AppendValues(rowIndices, nil)
+ rowIndicesArr := bldr.NewInt64Array()
+ defer rowIndicesArr.Release()
+
+ partitionedRecord, err := compute.Take(
+ ctx,
+ *compute.DefaultTakeOptions(),
+ compute.NewDatumWithoutOwning(record),
+ compute.NewDatumWithoutOwning(rowIndicesArr),
+ )
+ if err != nil {
+ return nil, err
+ }
+
+ return partitionedRecord.(*compute.RecordDatum).Value, nil
+ }
+}
+
+func getArrowValueAsIcebergLiteral(column arrow.Array, row int)
(iceberg.Literal, error) {
+ if column.IsNull(row) {
+ return nil, nil
+ }
+
+ switch arr := column.(type) {
+ case *array.Date32:
+
+ return iceberg.NewLiteral(iceberg.Date(arr.Value(row))), nil
+ case *array.Time64:
+
+ return iceberg.NewLiteral(iceberg.Time(arr.Value(row))), nil
+ case *array.Timestamp:
+
+ return iceberg.NewLiteral(iceberg.Timestamp(arr.Value(row))),
nil
+ case *array.Decimal32:
+ val := arr.Value(row)
+ dec := iceberg.Decimal{
+ Val: decimal128.FromU64(uint64(val)),
+ Scale: int(arr.DataType().(*arrow.Decimal32Type).Scale),
+ }
+
+ return iceberg.NewLiteral(dec), nil
+ case *array.Decimal64:
+ val := arr.Value(row)
+ dec := iceberg.Decimal{
+ Val: decimal128.FromU64(uint64(val)),
+ Scale: int(arr.DataType().(*arrow.Decimal64Type).Scale),
+ }
+
+ return iceberg.NewLiteral(dec), nil
+ case *array.Decimal128:
+ val := arr.Value(row)
+ dec := iceberg.Decimal{
+ Val: val,
+ Scale:
int(arr.DataType().(*arrow.Decimal128Type).Scale),
+ }
+
+ return iceberg.NewLiteral(dec), nil
+ case *extensions.UUIDArray:
+
+ return iceberg.NewLiteral(arr.Value(row)), nil
+ default:
+ val := column.GetOneForMarshal(row)
+ switch v := val.(type) {
+ case bool:
+ return iceberg.NewLiteral(v), nil
+ case int8:
+ return iceberg.NewLiteral(int32(v)), nil
+ case uint8:
+ return iceberg.NewLiteral(int32(v)), nil
+ case int16:
+ return iceberg.NewLiteral(int32(v)), nil
+ case uint16:
+ return iceberg.NewLiteral(int32(v)), nil
+ case int32:
+ return iceberg.NewLiteral(v), nil
+ case uint32:
+ return iceberg.NewLiteral(int32(v)), nil
+ case int64:
+ return iceberg.NewLiteral(v), nil
+ case uint64:
+ return iceberg.NewLiteral(int64(v)), nil
+ case float32:
+ return iceberg.NewLiteral(v), nil
+ case float64:
+ return iceberg.NewLiteral(v), nil
+ case string:
+ return iceberg.NewLiteral(v), nil
+ case []byte:
+ return iceberg.NewLiteral(v), nil
+ default:
+ return nil, fmt.Errorf("unsupported value type: %T", v)
+ }
+ }
+}
diff --git a/table/partitioned_fanout_writer_test.go
b/table/partitioned_fanout_writer_test.go
new file mode 100644
index 00000000..09c172e7
--- /dev/null
+++ b/table/partitioned_fanout_writer_test.go
@@ -0,0 +1,417 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+package table
+
+import (
+ "context"
+ "fmt"
+ "path/filepath"
+ "reflect"
+ "strconv"
+ "testing"
+ "time"
+
+ "github.com/apache/arrow-go/v18/arrow"
+ "github.com/apache/arrow-go/v18/arrow/array"
+ arrowdecimal "github.com/apache/arrow-go/v18/arrow/decimal"
+ "github.com/apache/arrow-go/v18/arrow/extensions"
+ "github.com/apache/arrow-go/v18/arrow/memory"
+ "github.com/apache/iceberg-go"
+ "github.com/apache/iceberg-go/config"
+
+ iceio "github.com/apache/iceberg-go/io"
+ "github.com/google/uuid"
+ "github.com/stretchr/testify/suite"
+)
+
// FanoutWriterTestSuite exercises the partitioned fanout writer across the
// supported partition transforms.
type FanoutWriterTestSuite struct {
	suite.Suite

	mem memory.Allocator // checked allocator shared by all record builders
	ctx context.Context
}
+
// SetupTest gives each test a fresh context and checked allocator.
func (s *FanoutWriterTestSuite) SetupTest() {
	s.ctx = context.Background()
	s.mem = memory.NewCheckedAllocator(memory.NewGoAllocator())
}
+
// TestFanoutWriter runs the suite.
func TestFanoutWriter(t *testing.T) {
	suite.Run(t, new(FanoutWriterTestSuite))
}
+
// createCustomTestRecord builds a record batch for arrSchema from row-major
// test data. Most cells are appended via reflection on the builder's Append
// method; UUID and []byte cells are handled explicitly because their
// builders' Append signatures don't fit the generic reflective call. nil
// cells become nulls.
func (s *FanoutWriterTestSuite) createCustomTestRecord(arrSchema *arrow.Schema, data [][]any) arrow.RecordBatch {
	bldr := array.NewRecordBuilder(s.mem, arrSchema)
	defer bldr.Release()

	for _, row := range data {
		for i, val := range row {
			field := bldr.Field(i)

			if val == nil {
				field.AppendNull()

				continue
			}

			v := reflect.ValueOf(val)
			appendMethod := reflect.ValueOf(field).MethodByName("Append")

			switch t := val.(type) {
			case uuid.UUID:
				field.(*extensions.UUIDBuilder).Append(t)
			case []byte:
				field.(*array.BinaryBuilder).Append(t)
			default:
				appendMethod.Call([]reflect.Value{v})
			}
		}
	}

	return bldr.NewRecordBatch()
}
+
+func (s *FanoutWriterTestSuite) testTransformPartition(transform
iceberg.Transform, sourceFieldName string, transformName string, testRecord
arrow.RecordBatch, expectedPartitionCount int) {
+ icebergSchema, err :=
ArrowSchemaToIcebergWithFreshIDs(testRecord.Schema(), false)
+ s.Require().NoError(err, "Failed to convert Arrow Schema to Iceberg
Schema")
+
+ sourceField, ok := icebergSchema.FindFieldByName(sourceFieldName)
+ s.Require().True(ok, "Source field %s not found in schema",
sourceFieldName)
+
+ spec := iceberg.NewPartitionSpec(
+ iceberg.PartitionField{
+ SourceID: sourceField.ID,
+ FieldID: 1000,
+ Transform: transform,
+ Name: "test_%s" + transformName,
+ },
+ )
+
+ loc := filepath.ToSlash(s.T().TempDir())
+ meta, err := NewMetadata(icebergSchema, &spec, UnsortedSortOrder, loc,
iceberg.Properties{})
+ s.Require().NoError(err)
+
+ metaBuilder, err := MetadataBuilderFromBase(meta)
+ s.Require().NoError(err)
+
+ args := recordWritingArgs{
+ sc: testRecord.Schema(),
+ itr: func(yield func(arrow.RecordBatch, error) bool) {
+ testRecord.Retain()
+ yield(testRecord, nil)
+ },
+ fs: iceio.LocalFS{},
+ writeUUID: func() *uuid.UUID {
+ u := uuid.New()
+
+ return &u
+ }(),
+ counter: func(yield func(int) bool) {
+ for i := 0; ; i++ {
+ if !yield(i) {
+ break
+ }
+ }
+ },
+ }
+
+ nameMapping := icebergSchema.NameMapping()
+ taskSchema, err := ArrowSchemaToIceberg(args.sc, false, nameMapping)
+ s.Require().NoError(err)
+
+ partitionWriter := newPartitionedFanoutWriter(spec, taskSchema,
args.itr)
+ rollingDataWriters := NewWriterFactory(loc, args, metaBuilder,
icebergSchema, 1024*1024)
+
+ partitionWriter.writers = &rollingDataWriters
+ workers := config.EnvConfig.MaxWorkers
+
+ dataFiles := partitionWriter.Write(s.ctx, workers)
+
+ fileCount := 0
+ totalRecords := int64(0)
+ partitionPaths := make(map[string]int64)
+
+ for dataFile, err := range dataFiles {
+ s.Require().NoError(err, "Transform %s should work",
transformName)
+ s.NotNil(dataFile)
+ fileCount++
+ totalRecords += dataFile.Count()
+
+ partitionRec := getPartitionRecord(dataFile,
spec.PartitionType(icebergSchema))
+ partitionPath := spec.PartitionToPath(partitionRec,
icebergSchema)
+ partitionPaths[partitionPath] += dataFile.Count()
+ }
+
+ s.Equal(expectedPartitionCount, fileCount, "Expected %d files, got %d",
expectedPartitionCount, fileCount)
+ s.Equal(totalRecords, testRecord.NumRows(), "Expected %d records, got
%d", testRecord.NumRows(), totalRecords)
+
+ s.T().Logf("Transform %s created %d partitions with distribution: %v",
transformName, fileCount, partitionPaths)
+}
+
// TestIdentityTransform: two distinct string values plus a null row should
// fan out into three partitions under the identity transform.
func (s *FanoutWriterTestSuite) TestIdentityTransform() {
	arrSchema := arrow.NewSchema([]arrow.Field{
		{Name: "id", Type: arrow.PrimitiveTypes.Int32, Nullable: true},
		{Name: "name", Type: arrow.BinaryTypes.String, Nullable: true},
	}, nil)

	testRecord := s.createCustomTestRecord(arrSchema, [][]any{
		{int32(1), "partition_a"},
		{int32(2), "partition_b"},
		{int32(3), "partition_a"},
		{int32(4), "partition_b"},
		{nil, nil},
	})
	defer testRecord.Release()

	s.testTransformPartition(iceberg.IdentityTransform{}, "name", "identity", testRecord, 3)
}
+
// TestBucketTransform buckets the id column into 3 buckets. The expected
// partition count of 3 is specific to this data set (ids 1-4 plus a null
// row); it depends on the bucket hash distribution.
func (s *FanoutWriterTestSuite) TestBucketTransform() {
	arrSchema := arrow.NewSchema([]arrow.Field{
		{Name: "id", Type: arrow.PrimitiveTypes.Int32, Nullable: true},
		{Name: "name", Type: arrow.BinaryTypes.String, Nullable: true},
	}, nil)

	testRecord := s.createCustomTestRecord(arrSchema, [][]any{
		{int32(1), "partition_a"},
		{int32(2), "partition_b"},
		{int32(3), "partition_a"},
		{int32(4), "partition_b"},
		{nil, nil},
	})
	defer testRecord.Release()

	s.testTransformPartition(iceberg.BucketTransform{NumBuckets: 3}, "id", "bucket", testRecord, 3)
}
+
// TestTruncateTransform: truncating names to width 3 yields prefixes "abc"
// and "def", plus the null row — three partitions.
func (s *FanoutWriterTestSuite) TestTruncateTransform() {
	arrSchema := arrow.NewSchema([]arrow.Field{
		{Name: "id", Type: arrow.PrimitiveTypes.Int32, Nullable: true},
		{Name: "name", Type: arrow.BinaryTypes.String, Nullable: true},
	}, nil)

	testRecord := s.createCustomTestRecord(arrSchema, [][]any{
		{int32(1), "abcdef"},
		{int32(2), "abcxyz"},
		{int32(3), "abcuvw"},
		{int32(4), "defghi"},
		{nil, nil},
	})
	defer testRecord.Release()

	s.testTransformPartition(iceberg.TruncateTransform{Width: 3}, "name", "truncate", testRecord, 3)
}
+
// TestYearTransform: the four epoch-day dates fall into two distinct years,
// which with the null row gives three partitions.
func (s *FanoutWriterTestSuite) TestYearTransform() {
	arrSchema := arrow.NewSchema([]arrow.Field{
		{Name: "id", Type: arrow.PrimitiveTypes.Int32, Nullable: true},
		{Name: "created_date", Type: arrow.PrimitiveTypes.Date32, Nullable: true},
	}, nil)

	testRecord := s.createCustomTestRecord(arrSchema, [][]any{
		{int32(1), arrow.Date32(19358)},
		{int32(2), arrow.Date32(19723)},
		{int32(3), arrow.Date32(19400)},
		{int32(4), arrow.Date32(19800)},
		{nil, nil},
	})
	defer testRecord.Release()

	s.testTransformPartition(iceberg.YearTransform{}, "created_date", "year", testRecord, 3)
}
+
// TestMonthTransform: the four epoch-day dates span two distinct months,
// which with the null row gives three partitions.
func (s *FanoutWriterTestSuite) TestMonthTransform() {
	arrSchema := arrow.NewSchema([]arrow.Field{
		{Name: "id", Type: arrow.PrimitiveTypes.Int32, Nullable: true},
		{Name: "created_date", Type: arrow.PrimitiveTypes.Date32, Nullable: true},
	}, nil)

	testRecord := s.createCustomTestRecord(arrSchema, [][]any{
		{int32(1), arrow.Date32(19358)},
		{int32(2), arrow.Date32(19386)},
		{int32(3), arrow.Date32(19389)},
		{int32(4), arrow.Date32(19416)},
		{nil, nil},
	})
	defer testRecord.Release()

	s.testTransformPartition(iceberg.MonthTransform{}, "created_date", "month", testRecord, 3)
}
+
// TestDayTransform: two pairs of rows share two distinct days; with the null
// row that makes three partitions.
func (s *FanoutWriterTestSuite) TestDayTransform() {
	arrSchema := arrow.NewSchema([]arrow.Field{
		{Name: "id", Type: arrow.PrimitiveTypes.Int32, Nullable: true},
		{Name: "created_date", Type: arrow.PrimitiveTypes.Date32, Nullable: true},
	}, nil)

	testRecord := s.createCustomTestRecord(arrSchema, [][]any{
		{int32(1), arrow.Date32(19358)},
		{int32(2), arrow.Date32(19358)},
		{int32(3), arrow.Date32(19359)},
		{int32(4), arrow.Date32(19359)},
		{nil, nil},
	})
	defer testRecord.Release()

	s.testTransformPartition(iceberg.DayTransform{}, "created_date", "day", testRecord, 3)
}
+
// TestHourTransform: the microsecond timestamps fall into two distinct hours;
// with the null row that makes three partitions.
func (s *FanoutWriterTestSuite) TestHourTransform() {
	arrSchema := arrow.NewSchema([]arrow.Field{
		{Name: "id", Type: arrow.PrimitiveTypes.Int32, Nullable: true},
		{Name: "created_ts", Type: &arrow.TimestampType{Unit: arrow.Microsecond}, Nullable: true},
	}, nil)

	testRecord := s.createCustomTestRecord(arrSchema, [][]any{
		{int32(1), arrow.Timestamp(1672531200000000)},
		{int32(2), arrow.Timestamp(1672531800000000)},
		{int32(3), arrow.Timestamp(1672534800000000)},
		{int32(4), arrow.Timestamp(1672535400000000)},
		{nil, nil},
	})
	defer testRecord.Release()

	s.testTransformPartition(iceberg.HourTransform{}, "created_ts", "hour", testRecord, 3)
}
+
// TestVoidTransform: the void transform maps every row (null or not) to the
// same partition, so exactly one file is expected.
func (s *FanoutWriterTestSuite) TestVoidTransform() {
	arrSchema := arrow.NewSchema([]arrow.Field{
		{Name: "id", Type: arrow.PrimitiveTypes.Int32, Nullable: true},
		{Name: "nothing", Type: arrow.PrimitiveTypes.Int32, Nullable: true},
	}, nil)

	testRecord := s.createCustomTestRecord(arrSchema, [][]any{
		{int32(1), int32(100)},
		{int32(2), int32(200)},
		{int32(3), int32(300)},
		{int32(4), int32(400)},
		{nil, nil},
	})
	defer testRecord.Release()

	s.testTransformPartition(iceberg.VoidTransform{}, "nothing", "void", testRecord, 1)
}
+
// TestPartitionedLogicalTypesRequireIntFieldIDCase runs a full AppendTable on
// a table partitioned by decimal, time, timestamp, timestamptz, UUID, and
// date columns (identity transforms), covering the Avro logical-type mapping
// of partition field IDs end to end, including null partition values.
func (s *FanoutWriterTestSuite) TestPartitionedLogicalTypesRequireIntFieldIDCase() {
	icebergSchema := iceberg.NewSchemaWithIdentifiers(1, []int{1},
		iceberg.NestedField{ID: 1, Name: "id", Type: iceberg.PrimitiveTypes.Int64, Required: true},
		iceberg.NestedField{ID: 2, Name: "decimal_col", Type: iceberg.DecimalTypeOf(10, 6), Required: true},
		iceberg.NestedField{ID: 3, Name: "time_col", Type: iceberg.PrimitiveTypes.Time, Required: true},
		iceberg.NestedField{ID: 4, Name: "timestamp_col", Type: iceberg.PrimitiveTypes.Timestamp, Required: true},
		iceberg.NestedField{ID: 5, Name: "timestamptz_col", Type: iceberg.PrimitiveTypes.TimestampTz, Required: true},
		iceberg.NestedField{ID: 6, Name: "uuid_col", Type: iceberg.PrimitiveTypes.UUID, Required: true},
		iceberg.NestedField{ID: 7, Name: "date_col", Type: iceberg.PrimitiveTypes.Date, Required: true},
	)

	spec := iceberg.NewPartitionSpec(
		iceberg.PartitionField{SourceID: 2, FieldID: 4008, Transform: iceberg.IdentityTransform{}, Name: "decimal_col"},
		iceberg.PartitionField{SourceID: 3, FieldID: 4009, Transform: iceberg.IdentityTransform{}, Name: "time_col"},
		iceberg.PartitionField{SourceID: 4, FieldID: 4010, Transform: iceberg.IdentityTransform{}, Name: "timestamp_col"},
		iceberg.PartitionField{SourceID: 5, FieldID: 4011, Transform: iceberg.IdentityTransform{}, Name: "timestamptz_col"},
		iceberg.PartitionField{SourceID: 6, FieldID: 4014, Transform: iceberg.IdentityTransform{}, Name: "uuid_col"},
		iceberg.PartitionField{SourceID: 7, FieldID: 4015, Transform: iceberg.IdentityTransform{}, Name: "date_col"},
	)

	loc := filepath.ToSlash(s.T().TempDir())
	meta, err := NewMetadata(icebergSchema, &spec, UnsortedSortOrder, loc, iceberg.Properties{})
	s.Require().NoError(err)

	tbl := New(
		Identifier{"test", "table"},
		meta,
		filepath.Join(loc, "metadata", "v1.metadata.json"),
		func(ctx context.Context) (iceio.IO, error) { return iceio.LocalFS{}, nil },
		nil,
	)

	record := s.createComprehensiveTestRecord()
	defer record.Release()
	arrowTable := array.NewTableFromRecords(record.Schema(), []arrow.RecordBatch{record})
	defer arrowTable.Release()

	snapshotProps := iceberg.Properties{
		"operation":  "append",
		"source":     "iceberg-go-fanout-test",
		"timestamp":  strconv.FormatInt(time.Now().Unix(), 10),
		"rows-added": strconv.FormatInt(int64(arrowTable.NumRows()), 10),
	}

	// One batch holds the whole table, so the append exercises a single
	// write pass.
	batchSize := int64(record.NumRows())
	txn := tbl.NewTransaction()
	err = txn.AppendTable(s.ctx, arrowTable, batchSize, snapshotProps)
	s.Require().NoError(err, "AppendTable should succeed with all primitive types")
}
+
// createComprehensiveTestRecord builds a four-row batch covering every
// logical-type column used by the partition-spec test above. Even rows carry
// values; odd rows are null in all non-id columns, exercising the
// null-partition path.
func (s *FanoutWriterTestSuite) createComprehensiveTestRecord() arrow.RecordBatch {
	pool := s.mem

	fields := []arrow.Field{
		{Name: "id", Type: arrow.PrimitiveTypes.Int64},
		{Name: "decimal_col", Type: &arrow.Decimal128Type{Precision: 10, Scale: 6}},
		{Name: "time_col", Type: arrow.FixedWidthTypes.Time64us},
		{Name: "timestamp_col", Type: &arrow.TimestampType{Unit: arrow.Microsecond}},
		{Name: "timestamptz_col", Type: &arrow.TimestampType{Unit: arrow.Microsecond, TimeZone: "UTC"}},
		{Name: "uuid_col", Type: extensions.NewUUIDType()},
		{Name: "date_col", Type: arrow.FixedWidthTypes.Date32},
	}
	arrSchema := arrow.NewSchema(fields, nil)

	idB := array.NewInt64Builder(pool)
	decB := array.NewDecimal128Builder(pool, &arrow.Decimal128Type{Precision: 10, Scale: 6})
	timeB := array.NewTime64Builder(pool, &arrow.Time64Type{Unit: arrow.Microsecond})
	tsB := array.NewTimestampBuilder(pool, &arrow.TimestampType{Unit: arrow.Microsecond})
	tstzB := array.NewTimestampBuilder(pool, &arrow.TimestampType{Unit: arrow.Microsecond, TimeZone: "UTC"})
	uuidB := extensions.NewUUIDBuilder(pool)
	dateB := array.NewDate32Builder(pool)

	for i := 0; i < 4; i++ {
		if i%2 == 0 {
			idB.Append(int64(i))
			val := fmt.Sprintf("%d.%06d", 123, i)
			// Conversion error deliberately ignored: the inputs above are
			// always valid for precision 10 / scale 6.
			arrowDec, _ := arrowdecimal.Decimal128FromString(val, 10, 6)
			decB.Append(arrowDec)
			timeB.Append(arrow.Time64(time.Duration(i * 1_000_000)))
			tsB.Append(arrow.Timestamp(1_600_000_000_000_000 + int64(i)*1_000_000))
			tstzB.Append(arrow.Timestamp(1_600_000_000_000_000 + int64(i)*1_000_000))
			uuidB.Append(uuid.New())
			dateB.Append(arrow.Date32(20000 + i))
		} else {
			idB.Append(int64(i))
			decB.AppendNull()
			timeB.AppendNull()
			tsB.AppendNull()
			tstzB.AppendNull()
			uuidB.AppendNull()
			dateB.AppendNull()
		}
	}

	cols := []arrow.Array{
		idB.NewArray(),
		decB.NewArray(),
		timeB.NewArray(),
		tsB.NewArray(),
		tstzB.NewArray(),
		uuidB.NewArray(),
		dateB.NewArray(),
	}

	record := array.NewRecordBatch(arrSchema, cols, int64(cols[0].Len()))

	return record
}
diff --git a/table/rolling_data_writer.go b/table/rolling_data_writer.go
new file mode 100644
index 00000000..ac43d46f
--- /dev/null
+++ b/table/rolling_data_writer.go
@@ -0,0 +1,237 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+package table
+
+import (
+ "context"
+ "fmt"
+ "iter"
+ "net/url"
+ "sync"
+ "sync/atomic"
+
+ "github.com/apache/arrow-go/v18/arrow"
+ "github.com/apache/iceberg-go"
+)
+
// writerFactory manages the creation and lifecycle of RollingDataWriter
// instances for different partitions, providing shared configuration and
// coordination across all writers in a partitioned write operation.
type writerFactory struct {
	rootLocation   string             // table root; data files go under <root>/data/<partition>
	args           recordWritingArgs  // shared write arguments (fs, write UUID, etc.)
	meta           *MetadataBuilder   // table metadata used when writing files
	taskSchema     *iceberg.Schema    // schema attached to each WriteTask
	targetFileSize int64              // size threshold at which writers roll to a new file
	writers        sync.Map           // partition key -> *RollingDataWriter
	counter        atomic.Int64       // global WriteTask ID counter across all partitions
	mu             sync.Mutex         // serializes writer creation in getOrCreateRollingDataWriter
}
+
// NewWriterFactory creates a new WriterFactory with the specified
// configuration for managing rolling data writers across partitions.
//
// NOTE(review): writerFactory embeds sync.Map, sync.Mutex and atomic.Int64,
// so returning it by value copies lock-bearing fields (`go vet` copylocks).
// It is safe only because the value is freshly constructed and unused;
// callers must take its address once and never copy it again. Consider
// returning *writerFactory in a follow-up.
func NewWriterFactory(rootLocation string, args recordWritingArgs, meta *MetadataBuilder, taskSchema *iceberg.Schema, targetFileSize int64) writerFactory {
	return writerFactory{
		rootLocation:   rootLocation,
		args:           args,
		meta:           meta,
		taskSchema:     taskSchema,
		targetFileSize: targetFileSize,
	}
}
+
// RollingDataWriter accumulates Arrow records for a specific partition and
// flushes them to data files when the target file size is reached,
// implementing a rolling file strategy to manage file sizes.
type RollingDataWriter struct {
	partitionKey    string                  // partition path; also the factory's map key
	recordCh        chan arrow.RecordBatch  // records pending bin-packing/flush
	errorCh         chan error              // carries the first flush error back to Add/closeAndWait
	factory         *writerFactory          // shared configuration and counters
	partitionValues map[int]any             // partition field ID -> value, passed to writeFiles
	ctx             context.Context         // cancelled by close()
	cancel          context.CancelFunc
	wg              sync.WaitGroup          // tracks the stream goroutine
}
+
// NewRollingDataWriter creates a new RollingDataWriter for the specified
// partition with the given partition values. It starts the writer's streaming
// goroutine immediately; callers must eventually stop it via closeAndWait
// (normally through the factory's closeAll).
func (w *writerFactory) NewRollingDataWriter(ctx context.Context, partition string, partitionValues map[int]any, outputDataFilesCh chan<- iceberg.DataFile) *RollingDataWriter {
	// Child context lets close() cancel a blocked Add without affecting the
	// parent operation.
	ctx, cancel := context.WithCancel(ctx)
	writer := &RollingDataWriter{
		partitionKey:    partition,
		recordCh:        make(chan arrow.RecordBatch, 64), // buffer decouples fanout workers from flushing
		errorCh:         make(chan error, 1),              // capacity 1: stream parks at most one error
		factory:         w,
		partitionValues: partitionValues,
		ctx:             ctx,
		cancel:          cancel,
	}

	writer.wg.Add(1)
	go writer.stream(outputDataFilesCh)

	return writer
}
+
+func (w *writerFactory) getOrCreateRollingDataWriter(ctx context.Context,
partition string, partitionValues map[int]any, outputDataFilesCh chan<-
iceberg.DataFile) (*RollingDataWriter, error) {
+ w.mu.Lock()
+ defer w.mu.Unlock()
+
+ if existing, ok := w.writers.Load(partition); ok {
+ if writer, ok := existing.(*RollingDataWriter); ok {
+ return writer, nil
+ }
+
+ return nil, fmt.Errorf("invalid writer type for partition: %s",
partition)
+ }
+
+ writer := w.NewRollingDataWriter(ctx, partition, partitionValues,
outputDataFilesCh)
+ w.writers.Store(partition, writer)
+
+ return writer, nil
+}
+
// Add hands a record to the writer's streaming goroutine. The record is
// retained before the send and released again if the send does not happen;
// ownership passes to the stream exactly once. If the stream has already
// failed, the parked error is returned instead; on cancellation the context
// error is returned.
//
// NOTE(review): Add must not race with close() — a send on the closed
// recordCh would panic. Confirm the current call sites (fanout workers finish
// before closeAll runs) preserve this invariant for any new caller.
func (r *RollingDataWriter) Add(record arrow.RecordBatch) error {
	record.Retain()
	select {
	case r.recordCh <- record:
		return nil
	case err := <-r.errorCh:
		record.Release()

		return err
	case <-r.ctx.Done():
		record.Release()

		return r.ctx.Err()
	}
}
+
// stream is the writer's background loop: it adapts recordCh into the
// iterator shape binPackRecords expects, groups incoming records into batches
// of at most 20 records bounded by targetFileSize, and flushes each group as
// one data file. On flush failure the error is parked on errorCh (capacity 1;
// the non-blocking send drops duplicates) and the loop exits. errorCh is
// closed on exit so closeAndWait's receive cannot hang.
func (r *RollingDataWriter) stream(outputDataFilesCh chan<- iceberg.DataFile) {
	defer r.wg.Done()
	defer close(r.errorCh)

	// Ends when close() closes recordCh.
	recordIter := func(yield func(arrow.RecordBatch, error) bool) {
		for record := range r.recordCh {
			if !yield(record, nil) {
				return
			}
		}
	}

	// NOTE(review): after an error, records still buffered in recordCh are
	// neither drained nor released here — confirm senders stop promptly via
	// errorCh/ctx so those records are not leaked.
	binPackedRecords := binPackRecords(recordIter, 20, r.factory.targetFileSize)
	for batch := range binPackedRecords {
		if err := r.flushToDataFile(batch, outputDataFilesCh); err != nil {
			select {
			case r.errorCh <- err:
			default:
			}

			return
		}
	}
}
+
+func (r *RollingDataWriter) flushToDataFile(batch []arrow.RecordBatch,
outputDataFilesCh chan<- iceberg.DataFile) error {
+ if len(batch) == 0 {
+ return nil
+ }
+
+ task := iter.Seq[WriteTask](func(yield func(WriteTask) bool) {
+ cnt := int(r.factory.counter.Add(1) - 1)
+
+ yield(WriteTask{
+ Uuid: *r.factory.args.writeUUID,
+ ID: cnt,
+ Schema: r.factory.taskSchema,
+ Batches: batch,
+ })
+ })
+
+ parseDataLoc, err := url.Parse(r.factory.rootLocation)
+ if err != nil {
+ return fmt.Errorf("failed to parse rootLocation: %v", err)
+ }
+
+ partitionMeta := *r.factory.meta
+ if partitionMeta.props == nil {
+ partitionMeta.props = make(map[string]string)
+ }
+ partitionMeta.props[WriteDataPathKey] =
parseDataLoc.JoinPath("data").JoinPath(r.partitionKey).String()
+
+ outputDataFiles := writeFiles(r.ctx, r.factory.rootLocation,
r.factory.args.fs, &partitionMeta, r.partitionValues, task)
+ for dataFile, err := range outputDataFiles {
+ if err != nil {
+ return err
+ }
+ outputDataFilesCh <- dataFile
+ }
+
+ for _, rec := range batch {
+ rec.Release()
+ }
+
+ return nil
+}
+
// close cancels the writer's context and closes its record channel,
// signalling the streaming goroutine to flush what it has buffered and exit.
// Must not be called while Add may still send (send on closed channel).
func (r *RollingDataWriter) close() {
	r.cancel()
	close(r.recordCh)
}
+
// closeAndWait closes the writer, removes it from the factory's registry, and
// blocks until its streaming goroutine has flushed and exited. Any error the
// goroutine parked on errorCh is wrapped and returned; since stream closes
// errorCh on exit, the receive below never blocks.
func (r *RollingDataWriter) closeAndWait() error {
	r.close()
	r.factory.writers.Delete(r.partitionKey)
	r.wg.Wait()

	select {
	case err := <-r.errorCh:
		if err != nil {
			return fmt.Errorf("error in rolling data writer: %w", err)
		}

		return nil
	default:

		return nil
	}
}
+
+func (w *writerFactory) closeAll() error {
+ var writers []*RollingDataWriter
+ w.writers.Range(func(key, value any) bool {
+ writer, ok := value.(*RollingDataWriter)
+ if ok {
+ writers = append(writers, writer)
+ }
+
+ return true
+ })
+
+ var err error
+ for _, writer := range writers {
+ if closeErr := writer.closeAndWait(); closeErr != nil && err ==
nil {
+ err = closeErr
+ }
+ }
+
+ return err
+}
diff --git a/table/snapshots_internal_test.go b/table/snapshots_internal_test.go
index d0e15817..796642f9 100644
--- a/table/snapshots_internal_test.go
+++ b/table/snapshots_internal_test.go
@@ -38,7 +38,7 @@ func TestSnapshotSummaryCollector(t *testing.T) {
assert.Equal(t, iceberg.Properties{}, ssc.build())
dataFile, err := iceberg.NewDataFileBuilder(*iceberg.UnpartitionedSpec,
- iceberg.EntryContentData, "/path/to/file.parquet",
iceberg.ParquetFile, nil, 100, 1234)
+ iceberg.EntryContentData, "/path/to/file.parquet",
iceberg.ParquetFile, nil, nil, nil, 100, 1234)
require.NoError(t, err)
require.NoError(t, ssc.addFile(dataFile.Build(), tableSchemaSimple,
*iceberg.UnpartitionedSpec))
@@ -64,11 +64,11 @@ func TestSnapshotSummaryCollectorWithPartition(t
*testing.T) {
dataFile1 := must(iceberg.NewDataFileBuilder(
spec, iceberg.EntryContentData, "/path/to/file1.parquet",
- iceberg.ParquetFile, map[int]any{1001: int32(1)}, 100,
1234)).Build()
+ iceberg.ParquetFile, map[int]any{1001: int32(1)}, nil, nil,
100, 1234)).Build()
dataFile2 := must(iceberg.NewDataFileBuilder(
spec, iceberg.EntryContentData, "/path/to/file2.parquet",
- iceberg.ParquetFile, map[int]any{1001: int32(2)}, 200,
4321)).Build()
+ iceberg.ParquetFile, map[int]any{1001: int32(2)}, nil, nil,
200, 4321)).Build()
ssc.addFile(dataFile1, sc, spec)
ssc.removeFile(dataFile1, sc, spec)
diff --git a/table/writer.go b/table/writer.go
index b77a6c8b..7a9155c2 100644
--- a/table/writer.go
+++ b/table/writer.go
@@ -53,7 +53,7 @@ type writer struct {
meta *MetadataBuilder
}
-func (w *writer) writeFile(ctx context.Context, task WriteTask)
(iceberg.DataFile, error) {
+func (w *writer) writeFile(ctx context.Context, partitionValues map[int]any,
task WriteTask) (iceberg.DataFile, error) {
defer func() {
for _, b := range task.Batches {
b.Release()
@@ -78,15 +78,21 @@ func (w *writer) writeFile(ctx context.Context, task
WriteTask) (iceberg.DataFil
filePath := w.loc.NewDataLocation(
task.GenerateDataFileName("parquet"))
- return w.format.WriteDataFile(ctx, w.fs, internal.WriteFileInfo{
+ currentSpec, err := w.meta.CurrentSpec()
+ if err != nil {
+ return nil, err
+ }
+
+ return w.format.WriteDataFile(ctx, w.fs, partitionValues,
internal.WriteFileInfo{
FileSchema: w.fileSchema,
FileName: filePath,
StatsCols: statsCols,
WriteProps: w.props,
+ Spec: *currentSpec,
}, batches)
}
-func writeFiles(ctx context.Context, rootLocation string, fs io.WriteFileIO,
meta *MetadataBuilder, tasks iter.Seq[WriteTask]) iter.Seq2[iceberg.DataFile,
error] {
+func writeFiles(ctx context.Context, rootLocation string, fs io.WriteFileIO,
meta *MetadataBuilder, partitionValues map[int]any, tasks iter.Seq[WriteTask])
iter.Seq2[iceberg.DataFile, error] {
locProvider, err := LoadLocationProvider(rootLocation, meta.props)
if err != nil {
return func(yield func(iceberg.DataFile, error) bool) {
@@ -122,6 +128,6 @@ func writeFiles(ctx context.Context, rootLocation string,
fs io.WriteFileIO, met
nworkers := config.EnvConfig.MaxWorkers
return internal.MapExec(nworkers, tasks, func(t WriteTask)
(iceberg.DataFile, error) {
- return w.writeFile(ctx, t)
+ return w.writeFile(ctx, partitionValues, t)
})
}
diff --git a/transforms_test.go b/transforms_test.go
index 7f92c29a..92e20513 100644
--- a/transforms_test.go
+++ b/transforms_test.go
@@ -209,6 +209,7 @@ func TestManifestPartitionVals(t *testing.T) {
partitionSpec, iceberg.EntryContentData,
"1234.parquet", iceberg.ParquetFile,
map[int]any{1000: result.Val.Any()},
+ nil, nil,
100, 100_000,
)
require.NoError(t, err)