This is an automated email from the ASF dual-hosted git repository. lostluck pushed a commit to branch master in repository https://gitbox.apache.org/repos/asf/beam.git
The following commit(s) were added to refs/heads/master by this push: new 2b96716ef5f Implement mongodbio.Read with an SDF (#25160) 2b96716ef5f is described below commit 2b96716ef5f1e575bb53cf3d23843d42faa45ce3 Author: Johanna Öjeling <51084516+johannaojel...@users.noreply.github.com> AuthorDate: Sat Jan 28 01:29:30 2023 +0100 Implement mongodbio.Read with an SDF (#25160) --- sdks/go/pkg/beam/io/mongodbio/coder.go | 25 +- sdks/go/pkg/beam/io/mongodbio/coder_test.go | 135 +++--- sdks/go/pkg/beam/io/mongodbio/common.go | 38 +- .../pkg/beam/io/mongodbio/id_range_restriction.go | 206 +++++++++ .../beam/io/mongodbio/id_range_restriction_test.go | 179 ++++++++ sdks/go/pkg/beam/io/mongodbio/id_range_split.go | 248 +++++++++++ .../pkg/beam/io/mongodbio/id_range_split_test.go | 275 ++++++++++++ sdks/go/pkg/beam/io/mongodbio/id_range_tracker.go | 194 +++++++++ .../pkg/beam/io/mongodbio/id_range_tracker_test.go | 461 +++++++++++++++++++++ sdks/go/pkg/beam/io/mongodbio/read.go | 402 +++++------------- sdks/go/pkg/beam/io/mongodbio/read_test.go | 323 --------------- 11 files changed, 1774 insertions(+), 712 deletions(-) diff --git a/sdks/go/pkg/beam/io/mongodbio/coder.go b/sdks/go/pkg/beam/io/mongodbio/coder.go index c140f9a8a25..3100f0fd93d 100644 --- a/sdks/go/pkg/beam/io/mongodbio/coder.go +++ b/sdks/go/pkg/beam/io/mongodbio/coder.go @@ -26,9 +26,14 @@ import ( func init() { beam.RegisterCoder( - reflect.TypeOf((*bson.M)(nil)).Elem(), - encodeBSONMap, - decodeBSONMap, + reflect.TypeOf((*idRangeRestriction)(nil)).Elem(), + encodeBSON[idRangeRestriction], + decodeBSON[idRangeRestriction], + ) + beam.RegisterCoder( + reflect.TypeOf((*idRange)(nil)).Elem(), + encodeBSON[idRange], + decodeBSON[idRange], ) beam.RegisterCoder( reflect.TypeOf((*primitive.ObjectID)(nil)).Elem(), @@ -37,19 +42,19 @@ func init() { ) } -func encodeBSONMap(m bson.M) ([]byte, error) { - bytes, err := bson.Marshal(m) +func encodeBSON[T any](in T) ([]byte, error) { + out, err := bson.Marshal(in) if err != 
nil { return nil, fmt.Errorf("error encoding BSON: %w", err) } - return bytes, nil + return out, nil } -func decodeBSONMap(bytes []byte) (bson.M, error) { - var out bson.M - if err := bson.Unmarshal(bytes, &out); err != nil { - return nil, fmt.Errorf("error decoding BSON: %w", err) +func decodeBSON[T any](in []byte) (T, error) { + var out T + if err := bson.Unmarshal(in, &out); err != nil { + return out, fmt.Errorf("error decoding BSON: %w", err) } return out, nil diff --git a/sdks/go/pkg/beam/io/mongodbio/coder_test.go b/sdks/go/pkg/beam/io/mongodbio/coder_test.go index d5e3bb2974d..98f81c50ef2 100644 --- a/sdks/go/pkg/beam/io/mongodbio/coder_test.go +++ b/sdks/go/pkg/beam/io/mongodbio/coder_test.go @@ -23,137 +23,102 @@ import ( "go.mongodb.org/mongo-driver/bson/primitive" ) -func Test_encodeBSONMap(t *testing.T) { +func Test_encodeDecodeBSONMap(t *testing.T) { tests := []struct { - name string - m bson.M - want []byte - wantErr bool + name string + val bson.M }{ { - name: "Encode bson.M", - m: bson.M{"key": "val"}, - want: []byte{18, 0, 0, 0, 2, 107, 101, 121, 0, 4, 0, 0, 0, 118, 97, 108, 0, 0}, - wantErr: false, + name: "Encode/decode bson.M", + val: bson.M{"key": "val"}, }, { - name: "Encode empty bson.M", - m: bson.M{}, - want: []byte{5, 0, 0, 0, 0}, - wantErr: false, - }, - { - name: "Encode nil bson.M", - m: bson.M(nil), - want: []byte{5, 0, 0, 0, 0}, - wantErr: false, - }, - { - name: "Error - invalid bson.M", - m: bson.M{"key": make(chan int)}, - wantErr: true, + name: "Encode/decode empty bson.M", + val: bson.M{}, }, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { - got, err := encodeBSONMap(tt.m) - if (err != nil) != tt.wantErr { - t.Fatalf("encodeBSONMap() error = %v, wantErr %v", err, tt.wantErr) + encoded, err := encodeBSON[bson.M](tt.val) + if err != nil { + t.Fatalf("encodeBSON[bson.M]() error = %v", err) + } + + decoded, err := decodeBSON[bson.M](encoded) + if err != nil { + t.Fatalf("decodeBSON[bson.M]() error = %v", err) } - if 
!cmp.Equal(got, tt.want) { - t.Errorf("encodeBSONMap() got = %v, want %v", got, tt.want) + if diff := cmp.Diff(tt.val, decoded); diff != "" { + t.Errorf("encode/decode mismatch (-want +got):\n%s", diff) } }) } } -func Test_decodeBSONMap(t *testing.T) { +func Test_encodeDecodeIDRangeRestriction(t *testing.T) { tests := []struct { - name string - bytes []byte - want bson.M - wantErr bool + name string + rest idRangeRestriction }{ { - name: "Decode bson.M", - bytes: []byte{18, 0, 0, 0, 2, 107, 101, 121, 0, 4, 0, 0, 0, 118, 97, 108, 0, 0}, - want: bson.M{"key": "val"}, - wantErr: false, + name: "Encode/decode idRangeRestriction", + rest: idRangeRestriction{ + IDRange: idRange{ + Min: objectIDFromHex(t, "5f1b2c3d4e5f60708090a0b0"), + MinInclusive: true, + Max: objectIDFromHex(t, "5f1b2c3d4e5f60708090a0b9"), + MaxInclusive: true, + }, + CustomFilter: bson.M{"key": "val"}, + Count: 5, + }, }, { - name: "Decode empty bson.M", - bytes: []byte{5, 0, 0, 0, 0}, - want: bson.M{}, - wantErr: false, - }, - { - name: "Error - invalid bson.M", - bytes: []byte{}, - wantErr: true, + name: "Encode/decode empty idRangeRestriction", + rest: idRangeRestriction{}, }, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { - got, err := decodeBSONMap(tt.bytes) - if (err != nil) != tt.wantErr { - t.Fatalf("decodeBSONMap() error = %v, wantErr %v", err, tt.wantErr) + encoded, err := encodeBSON[idRangeRestriction](tt.rest) + if err != nil { + t.Fatalf("encodeBSON[idRangeRestriction]() error = %v", err) + } + + decoded, err := decodeBSON[idRangeRestriction](encoded) + if err != nil { + t.Fatalf("decodeBSON[idRangeRestriction]() error = %v", err) } - if !cmp.Equal(got, tt.want) { - t.Errorf("decodeBSONMap() got = %v, want %v", got, tt.want) + if diff := cmp.Diff(tt.rest, decoded); diff != "" { + t.Errorf("encode/decode mismatch (-want +got):\n%s", diff) } }) } } -func Test_encodeObjectID(t *testing.T) { +func Test_encodeDecodeObjectID(t *testing.T) { tests := []struct { name string 
objectID primitive.ObjectID - want []byte }{ { - name: "Encode object ID", + name: "Encode/decode object ID", objectID: objectIDFromHex(t, "5f1b2c3d4e5f60708090a0b0"), - want: []byte{95, 27, 44, 61, 78, 95, 96, 112, 128, 144, 160, 176}, }, { - name: "Encode nil object ID", + name: "Encode/decode nil object ID", objectID: primitive.NilObjectID, - want: []byte{0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0}, }, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { - if got := encodeObjectID(tt.objectID); !cmp.Equal(got, tt.want) { - t.Errorf("encodeObjectID() = %v, want %v", got, tt.want) - } - }) - } -} + encoded := encodeObjectID(tt.objectID) + decoded := decodeObjectID(encoded) -func Test_decodeObjectID(t *testing.T) { - tests := []struct { - name string - bytes []byte - want primitive.ObjectID - }{ - { - name: "Decode object ID", - bytes: []byte{95, 27, 44, 61, 78, 95, 96, 112, 128, 144, 160, 176}, - want: objectIDFromHex(t, "5f1b2c3d4e5f60708090a0b0"), - }, - { - name: "Decode nil object ID", - bytes: []byte{0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0}, - want: primitive.NilObjectID, - }, - } - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - if got := decodeObjectID(tt.bytes); !cmp.Equal(got, tt.want) { - t.Errorf("decodeObjectID() = %v, want %v", got, tt.want) + if !cmp.Equal(decoded, tt.objectID) { + t.Errorf("decodeObjectID() = %v, want %v", decoded, tt.objectID) } }) } diff --git a/sdks/go/pkg/beam/io/mongodbio/common.go b/sdks/go/pkg/beam/io/mongodbio/common.go index 9d6ffbeaa95..e1d0657a206 100644 --- a/sdks/go/pkg/beam/io/mongodbio/common.go +++ b/sdks/go/pkg/beam/io/mongodbio/common.go @@ -20,6 +20,7 @@ import ( "context" "fmt" + "go.mongodb.org/mongo-driver/bson" "go.mongodb.org/mongo-driver/mongo" "go.mongodb.org/mongo-driver/mongo/options" "go.mongodb.org/mongo-driver/mongo/readpref" @@ -38,13 +39,16 @@ type mongoDBFn struct { } func (fn *mongoDBFn) Setup(ctx context.Context) error { - client, err := newClient(ctx, fn.URI) - if err != nil { - 
return err + if fn.client == nil { + client, err := newClient(ctx, fn.URI) + if err != nil { + return err + } + + fn.client = client } - fn.client = client - fn.collection = client.Database(fn.Database).Collection(fn.Collection) + fn.collection = fn.client.Database(fn.Database).Collection(fn.Collection) return nil } @@ -71,3 +75,27 @@ func (fn *mongoDBFn) Teardown(ctx context.Context) error { return nil } + +type documentID struct { + ID any `bson:"_id"` +} + +func findID( + ctx context.Context, + collection *mongo.Collection, + filter any, + order int, + skip int64, +) (any, error) { + opts := options.FindOne(). + SetProjection(bson.M{"_id": 1}). + SetSort(bson.M{"_id": order}). + SetSkip(skip) + + var docID documentID + if err := collection.FindOne(ctx, filter, opts).Decode(&docID); err != nil { + return nil, err + } + + return docID.ID, nil +} diff --git a/sdks/go/pkg/beam/io/mongodbio/id_range_restriction.go b/sdks/go/pkg/beam/io/mongodbio/id_range_restriction.go new file mode 100644 index 00000000000..c527631cd66 --- /dev/null +++ b/sdks/go/pkg/beam/io/mongodbio/id_range_restriction.go @@ -0,0 +1,206 @@ +// Licensed to the Apache Software Foundation (ASF) under one or more +// contributor license agreements. See the NOTICE file distributed with +// this work for additional information regarding copyright ownership. +// The ASF licenses this file to You under the Apache License, Version 2.0 +// (the "License"); you may not use this file except in compliance with +// the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +package mongodbio + +import ( + "context" + "errors" + "fmt" + "math" + "reflect" + + "github.com/apache/beam/sdks/v2/go/pkg/beam" + "go.mongodb.org/mongo-driver/bson" + "go.mongodb.org/mongo-driver/mongo" +) + +func init() { + beam.RegisterType(reflect.TypeOf((*idRangeRestriction)(nil)).Elem()) + beam.RegisterType(reflect.TypeOf((*idRange)(nil)).Elem()) +} + +// idRangeRestriction represents a range of document IDs to read from MongoDB. IDRange holds +// information about the minimum and maximum IDs. CustomFilter is the custom filter to apply when +// reading from the collection. Count is the number of documents within the ID range that match the +// custom filter. +type idRangeRestriction struct { + IDRange idRange + CustomFilter bson.M + Count int64 +} + +// newIDRangeRestriction creates a new idRangeRestriction and counts the documents within the ID +// range that match the custom filter. +func newIDRangeRestriction( + ctx context.Context, + collection *mongo.Collection, + idRange idRange, + filter bson.M, +) idRangeRestriction { + mergedFilter := mergeFilters(idRange.Filter(), filter) + + count, err := collection.CountDocuments(ctx, mergedFilter) + if err != nil { + panic(err) + } + + return idRangeRestriction{ + IDRange: idRange, + CustomFilter: filter, + Count: count, + } +} + +// Filter returns a bson.M filter based on the restriction's ID range and custom filter. +func (r idRangeRestriction) Filter() bson.M { + idFilter := r.IDRange.Filter() + return mergeFilters(idFilter, r.CustomFilter) +} + +// mergeFilters merges the ID filter and the custom filter into a single bson.M filter. +func mergeFilters(idFilter bson.M, customFilter bson.M) bson.M { + if len(idFilter) == 0 { + return customFilter + } + + if len(customFilter) == 0 { + return idFilter + } + + return bson.M{ + "$and": []bson.M{idFilter, customFilter}, + } +} + +// SizedSplits divides the restriction into sub-restrictions based on the desired bundle size in +// bytes. 
+func (r idRangeRestriction) SizedSplits( + ctx context.Context, + collection *mongo.Collection, + bundleSize int64, + useBucketAuto bool, +) ([]idRangeRestriction, error) { + var idRanges []idRange + var err error + + if useBucketAuto { + idRanges, err = bucketAutoSplits(ctx, collection, r.IDRange, bundleSize) + } else { + idRanges, err = splitVectorSplits(ctx, collection, r.IDRange, bundleSize) + } + + if err != nil { + return nil, err + } + + return restrictionsFromIDRanges(ctx, collection, idRanges, r.CustomFilter), err +} + +// FractionSplits divides the restriction into a lower and higher ID sub-restriction based on the +// desired fraction of work the lower piece should be responsible for. +func (r idRangeRestriction) FractionSplits( + ctx context.Context, + collection *mongo.Collection, + fraction float64, +) (lower, higher idRangeRestriction, err error) { + skip := int64(math.Round(float64(r.Count) * fraction)) + + splitID, err := findID(ctx, collection, r.Filter(), 1, skip) + if err != nil { + if errors.Is(err, mongo.ErrNoDocuments) { + return idRangeRestriction{}, idRangeRestriction{}, nil + } + + return idRangeRestriction{}, idRangeRestriction{}, fmt.Errorf( + "error finding document ID to split on: %w", + err, + ) + } + + lower = idRangeRestriction{ + IDRange: idRange{ + Min: r.IDRange.Min, + MinInclusive: r.IDRange.MinInclusive, + Max: splitID, + MaxInclusive: false, + }, + CustomFilter: r.CustomFilter, + Count: skip, + } + + higher = idRangeRestriction{ + IDRange: idRange{ + Min: splitID, + MinInclusive: true, + Max: r.IDRange.Max, + MaxInclusive: r.IDRange.MaxInclusive, + }, + CustomFilter: r.CustomFilter, + Count: r.Count - skip, + } + + return lower, higher, nil +} + +// restrictionsFromIDRanges creates a slice of new restrictions based on the ID ranges. 
+func restrictionsFromIDRanges( + ctx context.Context, + collection *mongo.Collection, + idRanges []idRange, + customFilter bson.M, +) []idRangeRestriction { + restrictions := make([]idRangeRestriction, len(idRanges)) + + for i := 0; i < len(idRanges); i++ { + rest := newIDRangeRestriction( + ctx, + collection, + idRanges[i], + customFilter, + ) + restrictions[i] = rest + } + + return restrictions +} + +// idRange represents a range of document IDs in a MongoDB collection. It stores information about +// the minimum and maximum IDs, and whether they are inclusive or not. +type idRange struct { + Min any + MinInclusive bool + Max any + MaxInclusive bool +} + +// Filter creates a bson.M filter representation of the idRange. +func (i idRange) Filter() bson.M { + filter := make(bson.M, 2) + + if i.MinInclusive { + filter["$gte"] = i.Min + } else { + filter["$gt"] = i.Min + } + + if i.MaxInclusive { + filter["$lte"] = i.Max + } else { + filter["$lt"] = i.Max + } + + return bson.M{"_id": filter} +} diff --git a/sdks/go/pkg/beam/io/mongodbio/id_range_restriction_test.go b/sdks/go/pkg/beam/io/mongodbio/id_range_restriction_test.go new file mode 100644 index 00000000000..0534424ef05 --- /dev/null +++ b/sdks/go/pkg/beam/io/mongodbio/id_range_restriction_test.go @@ -0,0 +1,179 @@ +// Licensed to the Apache Software Foundation (ASF) under one or more +// contributor license agreements. See the NOTICE file distributed with +// this work for additional information regarding copyright ownership. +// The ASF licenses this file to You under the Apache License, Version 2.0 +// (the "License"); you may not use this file except in compliance with +// the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+// See the License for the specific language governing permissions and +// limitations under the License. + +package mongodbio + +import ( + "testing" + + "github.com/google/go-cmp/cmp" + "go.mongodb.org/mongo-driver/bson" +) + +func Test_mergeFilters(t *testing.T) { + tests := []struct { + name string + idFilter bson.M + filter bson.M + want bson.M + }{ + { + name: "Merge ID filter and custom filter in an $and filter", + idFilter: bson.M{ + "_id": bson.M{ + "$gte": 10, + }, + }, + filter: bson.M{ + "key": bson.M{ + "$ne": "value", + }, + }, + want: bson.M{ + "$and": []bson.M{ + { + "_id": bson.M{ + "$gte": 10, + }, + }, + { + "key": bson.M{ + "$ne": "value", + }, + }, + }, + }, + }, + { + name: "Keep only ID filter when custom filter is empty", + idFilter: bson.M{ + "_id": bson.M{ + "$gte": 10, + }, + }, + filter: bson.M{}, + want: bson.M{ + "_id": bson.M{ + "$gte": 10, + }, + }, + }, + { + name: "Keep only custom filter when ID filter is empty", + idFilter: bson.M{}, + filter: bson.M{ + "key": bson.M{ + "$ne": "value", + }, + }, + want: bson.M{ + "key": bson.M{ + "$ne": "value", + }, + }, + }, + { + name: "Empty filter when both ID filter and custom filter are empty", + idFilter: bson.M{}, + filter: bson.M{}, + want: bson.M{}, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got := mergeFilters(tt.idFilter, tt.filter) + if diff := cmp.Diff(got, tt.want); diff != "" { + t.Errorf("mergeFilters() mismatch (-want +got): %v", diff) + } + }) + } +} + +func Test_idRange_Filter(t *testing.T) { + tests := []struct { + name string + idRange idRange + want bson.M + }{ + { + name: "ID filter with $gte when min is inclusive", + idRange: idRange{ + Min: 0, + MinInclusive: true, + Max: 10, + MaxInclusive: false, + }, + want: bson.M{ + "_id": bson.M{ + "$gte": 0, + "$lt": 10, + }, + }, + }, + { + name: "ID filter with $gt when min is exclusive", + idRange: idRange{ + Min: 0, + MinInclusive: false, + Max: 10, + MaxInclusive: false, + }, + want: 
bson.M{ + "_id": bson.M{ + "$gt": 0, + "$lt": 10, + }, + }, + }, + { + name: "ID filter with $lte when max is inclusive", + idRange: idRange{ + Min: 0, + MinInclusive: true, + Max: 10, + MaxInclusive: true, + }, + want: bson.M{ + "_id": bson.M{ + "$gte": 0, + "$lte": 10, + }, + }, + }, + { + name: "ID filter with $lt when max is exclusive", + idRange: idRange{ + Min: 0, + MinInclusive: true, + Max: 10, + MaxInclusive: false, + }, + want: bson.M{ + "_id": bson.M{ + "$gte": 0, + "$lt": 10, + }, + }, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got := tt.idRange.Filter() + if diff := cmp.Diff(tt.want, got); diff != "" { + t.Errorf("Filter() mismatch (-want +got): %v", diff) + } + }) + } +} diff --git a/sdks/go/pkg/beam/io/mongodbio/id_range_split.go b/sdks/go/pkg/beam/io/mongodbio/id_range_split.go new file mode 100644 index 00000000000..87a6d952866 --- /dev/null +++ b/sdks/go/pkg/beam/io/mongodbio/id_range_split.go @@ -0,0 +1,248 @@ +// Licensed to the Apache Software Foundation (ASF) under one or more +// contributor license agreements. See the NOTICE file distributed with +// this work for additional information regarding copyright ownership. +// The ASF licenses this file to You under the Apache License, Version 2.0 +// (the "License"); you may not use this file except in compliance with +// the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +package mongodbio + +import ( + "context" + "fmt" + "math" + + "go.mongodb.org/mongo-driver/bson" + "go.mongodb.org/mongo-driver/mongo" + "go.mongodb.org/mongo-driver/mongo/options" + "go.mongodb.org/mongo-driver/mongo/readpref" +) + +const ( + maxBucketCount = math.MaxInt32 + minSplitVectorChunkSize = 1024 * 1024 + maxSplitVectorChunkSize = 1024 * 1024 * 1024 +) + +func bucketAutoSplits( + ctx context.Context, + collection *mongo.Collection, + outerRange idRange, + bundleSize int64, +) ([]idRange, error) { + collSize, err := getCollectionSize(ctx, collection) + if err != nil { + return nil, err + } + + bucketCount := calculateBucketCount(collSize, bundleSize) + + buckets, err := getBuckets(ctx, collection, outerRange.Filter(), bucketCount) + if err != nil { + return nil, err + } + + return idRangesFromBuckets(buckets, outerRange), nil +} + +func getCollectionSize(ctx context.Context, collection *mongo.Collection) (int64, error) { + cmd := bson.M{"collStats": collection.Name()} + opts := options.RunCmd().SetReadPreference(readpref.Primary()) + + var stats struct { + Size int64 `bson:"size"` + } + if err := collection.Database().RunCommand(ctx, cmd, opts).Decode(&stats); err != nil { + return 0, fmt.Errorf("error executing collStats command: %w", err) + } + + return stats.Size, nil +} + +func calculateBucketCount(totalSize int64, bundleSize int64) int32 { + if bundleSize < 0 { + panic("monogdbio.calculateBucketCount: bundle size must be greater than 0") + } + + count := totalSize / bundleSize + if totalSize%bundleSize != 0 { + count++ + } + + if count > int64(maxBucketCount) { + count = maxBucketCount + } + + return int32(count) +} + +type bucket struct { + ID minMax `bson:"_id"` +} + +type minMax struct { + Min any `bson:"min"` + Max any `bson:"max"` +} + +func getBuckets( + ctx context.Context, + collection *mongo.Collection, + filter bson.M, + count int32, +) ([]bucket, error) { + pipeline := mongo.Pipeline{ + bson.D{{ + Key: "$match", + Value: filter, + }}, + 
bson.D{{ + Key: "$bucketAuto", + Value: bson.M{ + "groupBy": "$_id", + "buckets": count, + }, + }}, + } + + opts := options.Aggregate().SetAllowDiskUse(true) + + cursor, err := collection.Aggregate(ctx, pipeline, opts) + if err != nil { + return nil, fmt.Errorf("error executing bucketAuto aggregation: %w", err) + } + + var buckets []bucket + if err := cursor.All(ctx, &buckets); err != nil { + return nil, fmt.Errorf("error decoding buckets: %w", err) + } + + return buckets, nil +} + +func idRangesFromBuckets(buckets []bucket, outerRange idRange) []idRange { + if len(buckets) == 0 { + return nil + } + + ranges := make([]idRange, len(buckets)) + + for i := 0; i < len(buckets); i++ { + subRange := idRange{} + + if i == 0 { + subRange.MinInclusive = outerRange.MinInclusive + subRange.Min = outerRange.Min + } else { + subRange.Min = buckets[i].ID.Min + subRange.MinInclusive = true + } + + if i == len(buckets)-1 { + subRange.Max = outerRange.Max + subRange.MaxInclusive = outerRange.MaxInclusive + } else { + subRange.Max = buckets[i].ID.Max + subRange.MaxInclusive = false + } + + ranges[i] = subRange + } + + return ranges +} + +func splitVectorSplits( + ctx context.Context, + collection *mongo.Collection, + outerRange idRange, + bundleSize int64, +) ([]idRange, error) { + chunkSize := getChunkSize(bundleSize) + + splitKeys, err := getSplitKeys(ctx, collection, outerRange, chunkSize) + if err != nil { + return nil, err + } + + return idRangesFromSplits(splitKeys, outerRange), nil +} + +func getChunkSize(bundleSize int64) int64 { + var chunkSize int64 + + if bundleSize < minSplitVectorChunkSize { + chunkSize = minSplitVectorChunkSize + } else if bundleSize > maxSplitVectorChunkSize { + chunkSize = maxSplitVectorChunkSize + } else { + chunkSize = bundleSize + } + + return chunkSize +} + +func getSplitKeys( + ctx context.Context, + collection *mongo.Collection, + outerRange idRange, + maxChunkSizeBytes int64, +) ([]documentID, error) { + database := collection.Database() + 
namespace := fmt.Sprintf("%s.%s", database.Name(), collection.Name()) + + cmd := bson.D{ + {Key: "splitVector", Value: namespace}, + {Key: "keyPattern", Value: bson.D{{Key: "_id", Value: 1}}}, + {Key: "min", Value: bson.D{{Key: "_id", Value: outerRange.Min}}}, + {Key: "max", Value: bson.D{{Key: "_id", Value: outerRange.Max}}}, + {Key: "maxChunkSizeBytes", Value: maxChunkSizeBytes}, + } + + opts := options.RunCmd().SetReadPreference(readpref.Primary()) + + var result struct { + SplitKeys []documentID `bson:"splitKeys"` + } + if err := database.RunCommand(ctx, cmd, opts).Decode(&result); err != nil { + return nil, fmt.Errorf("error executing splitVector command: %w", err) + } + + return result.SplitKeys, nil +} + +func idRangesFromSplits(splitKeys []documentID, outerRange idRange) []idRange { + subRanges := make([]idRange, len(splitKeys)+1) + + for i := 0; i < len(splitKeys)+1; i++ { + subRange := idRange{} + + if i == 0 { + subRange.Min = outerRange.Min + subRange.MinInclusive = outerRange.MinInclusive + } else { + subRange.Min = splitKeys[i-1].ID + subRange.MinInclusive = true + } + + if i == len(splitKeys) { + subRange.Max = outerRange.Max + subRange.MaxInclusive = outerRange.MaxInclusive + } else { + subRange.Max = splitKeys[i].ID + subRange.MaxInclusive = false + } + + subRanges[i] = subRange + } + + return subRanges +} diff --git a/sdks/go/pkg/beam/io/mongodbio/id_range_split_test.go b/sdks/go/pkg/beam/io/mongodbio/id_range_split_test.go new file mode 100644 index 00000000000..d3e867f6be2 --- /dev/null +++ b/sdks/go/pkg/beam/io/mongodbio/id_range_split_test.go @@ -0,0 +1,275 @@ +// Licensed to the Apache Software Foundation (ASF) under one or more +// contributor license agreements. See the NOTICE file distributed with +// this work for additional information regarding copyright ownership. +// The ASF licenses this file to You under the Apache License, Version 2.0 +// (the "License"); you may not use this file except in compliance with +// the License. 
You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package mongodbio + +import ( + "math" + "testing" + + "github.com/google/go-cmp/cmp" +) + +func Test_calculateBucketCount(t *testing.T) { + tests := []struct { + name string + totalSize int64 + bundleSize int64 + want int32 + }{ + { + name: "Return ceiling of total size / bundle size", + totalSize: 3 * 1024 * 1024, + bundleSize: 2 * 1024 * 1024, + want: 2, + }, + { + name: "Return max int32 when calculated count is greater than max int32", + totalSize: 1024 * 1024 * 1024 * 1024, + bundleSize: 1, + want: math.MaxInt32, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + if got := calculateBucketCount(tt.totalSize, tt.bundleSize); got != tt.want { + t.Errorf("calculateBucketCount() = %v, want %v", got, tt.want) + } + }) + } +} + +func Test_calculateBucketCountPanic(t *testing.T) { + t.Run("Panic when bundleSize is not greater than 0", func(t *testing.T) { + defer func() { + if r := recover(); r == nil { + t.Errorf("calculateBucketCount() does not panic") + } + }() + + calculateBucketCount(1024, 0) + }) +} + +func Test_idRangesFromBuckets(t *testing.T) { + tests := []struct { + name string + buckets []bucket + outerRange idRange + want []idRange + }{ + { + name: "ID ranges with first element having min ID configuration from outer range, and last element " + + "having max ID configuration from outer range", + buckets: []bucket{ + { + ID: minMax{ + Min: 5, + Max: 100, + }, + }, + { + ID: minMax{ + Min: 100, + Max: 200, + }, + }, + { + ID: minMax{ + Min: 200, + Max: 295, + }, + }, + }, + outerRange: idRange{ + Min: 0, + 
MinInclusive: false, + Max: 300, + MaxInclusive: false, + }, + want: []idRange{ + { + Min: 0, + MinInclusive: false, + Max: 100, + MaxInclusive: false, + }, + { + Min: 100, + MinInclusive: true, + Max: 200, + MaxInclusive: false, + }, + { + Min: 200, + MinInclusive: true, + Max: 300, + MaxInclusive: false, + }, + }, + }, + { + name: "ID ranges with one element having the same configuration as outer range when there is one " + + "element in buckets", + buckets: []bucket{ + { + ID: minMax{ + Min: 5, + Max: 95, + }, + }, + }, + outerRange: idRange{ + Min: 0, + MinInclusive: false, + Max: 100, + MaxInclusive: false, + }, + want: []idRange{ + { + Min: 0, + MinInclusive: false, + Max: 100, + MaxInclusive: false, + }, + }, + }, + { + name: "Empty ID ranges when there are no elements in buckets", + buckets: nil, + outerRange: idRange{ + Min: 0, + MinInclusive: false, + Max: 100, + MaxInclusive: false, + }, + want: nil, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got := idRangesFromBuckets(tt.buckets, tt.outerRange) + if diff := cmp.Diff(got, tt.want); diff != "" { + t.Errorf("idRangesFromBuckets() mismatch (-want +got):\n%s", diff) + } + }) + } +} + +func Test_getChunkSize(t *testing.T) { + tests := []struct { + name string + bundleSize int64 + want int64 + }{ + { + name: "Return 1 MB if bundle size is less than 1 MB", + bundleSize: 1024, + want: 1024 * 1024, + }, + { + name: "Return 1 GB if bundle size is greater than 1 GB", + bundleSize: 2 * 1024 * 1024 * 1024, + want: 1024 * 1024 * 1024, + }, + { + name: "Return bundle size if bundle size is between 1 MB and 1 GB", + bundleSize: 4 * 1024 * 1024, + want: 4 * 1024 * 1024, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + if got := getChunkSize(tt.bundleSize); got != tt.want { + t.Errorf("getChunkSize() = %v, want %v", got, tt.want) + } + }) + } +} + +func Test_idRangesFromSplits(t *testing.T) { + tests := []struct { + name string + splitKeys []documentID + 
outerRange idRange + want []idRange + }{ + { + name: "ID ranges with first element having min ID configuration from outer range, and last element " + + "having max ID configuration from outer range", + splitKeys: []documentID{ + { + ID: 100, + }, + { + ID: 200, + }, + }, + outerRange: idRange{ + Min: 0, + MinInclusive: false, + Max: 300, + MaxInclusive: false, + }, + want: []idRange{ + { + Min: 0, + MinInclusive: false, + Max: 100, + MaxInclusive: false, + }, + { + Min: 100, + MinInclusive: true, + Max: 200, + MaxInclusive: false, + }, + { + Min: 200, + MinInclusive: true, + Max: 300, + MaxInclusive: false, + }, + }, + }, + { + name: "ID ranges with one element having the same configuration as outer range when there are no " + + "elements in key splits", + splitKeys: nil, + outerRange: idRange{ + Min: 0, + MinInclusive: true, + Max: 100, + MaxInclusive: true, + }, + want: []idRange{ + { + Min: 0, + MinInclusive: true, + Max: 100, + MaxInclusive: true, + }, + }, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got := idRangesFromSplits(tt.splitKeys, tt.outerRange) + if diff := cmp.Diff(got, tt.want); diff != "" { + t.Errorf("idRangesFromSplits() mismatch (-want +got):\n%s", diff) + } + }) + } +} diff --git a/sdks/go/pkg/beam/io/mongodbio/id_range_tracker.go b/sdks/go/pkg/beam/io/mongodbio/id_range_tracker.go new file mode 100644 index 00000000000..6b92ab57d49 --- /dev/null +++ b/sdks/go/pkg/beam/io/mongodbio/id_range_tracker.go @@ -0,0 +1,194 @@ +// Licensed to the Apache Software Foundation (ASF) under one or more +// contributor license agreements. See the NOTICE file distributed with +// this work for additional information regarding copyright ownership. +// The ASF licenses this file to You under the Apache License, Version 2.0 +// (the "License"); you may not use this file except in compliance with +// the License. 
You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package mongodbio + +import ( + "context" + "errors" + "fmt" + "reflect" + + "github.com/apache/beam/sdks/v2/go/pkg/beam" + "go.mongodb.org/mongo-driver/mongo" +) + +func init() { + beam.RegisterType(reflect.TypeOf((*idRangeTracker)(nil))) +} + +// idRangeTracker is a tracker of an idRangeRestriction. +type idRangeTracker struct { + rest idRangeRestriction + collection *mongo.Collection + claimed int64 + claimedID any + stopped bool + err error +} + +// newIDRangeTracker creates a new idRangeTracker tracking the provided idRangeRestriction. +func newIDRangeTracker(rest idRangeRestriction, collection *mongo.Collection) *idRangeTracker { + return &idRangeTracker{ + rest: rest, + collection: collection, + } +} + +// cursorResult holds information about the next document to process from MongoDB. nextID is the ID +// of the document. isExhausted is whether the cursor has been exhausted. +type cursorResult struct { + nextID any + isExhausted bool +} + +// TryClaim accepts a position representing a cursorResult of a document to read from MongoDB. The +// position is successfully claimed if the tracker has not yet completed the work within its +// restriction and the cursor has not been exhausted. 
+func (rt *idRangeTracker) TryClaim(pos any) (ok bool) {
+	cur, isCursorResult := pos.(cursorResult)
+
+	switch {
+	case !isCursorResult:
+		// Record the misuse so that it can be retrieved through GetError.
+		rt.err = fmt.Errorf("invalid pos type: %T", pos)
+		return false
+	case rt.IsDone():
+		// The tracker was stopped or has already claimed its full count.
+		return false
+	case cur.isExhausted:
+		// The cursor has nothing more to offer: finish the tracker without counting a claim.
+		rt.stopped = true
+		return false
+	}
+
+	// Accept the position and remember its ID as the latest completed document.
+	rt.claimed++
+	rt.claimedID = cur.nextID
+
+	return true
+}
+
+// GetError reports the error recorded by the tracker, or nil when no error has been recorded.
+func (rt *idRangeTracker) GetError() error {
+	return rt.err
+}
+
+// TrySplit splits the underlying restriction into a primary and residual restriction based on the
+// fraction of remaining work the primary should be responsible for. The restriction may be modified
+// as a result of the split. The primary is a copy of the tracker's restriction after the split.
+// If the fraction is 1 or all work has already been claimed, returns the full restriction as the
+// primary and nil as the residual. If the fraction is 0, stops the tracker, cuts off any remaining
+// work from its underlying restriction, and returns a residual representing all remaining work.
+// If the fraction is between 0 and 1, attempts to split the remaining work of the underlying
+// restriction into two sub-restrictions based on the fraction and assigns them to the primary and
+// residual respectively. Returns an error if the split cannot be performed.
+func (rt *idRangeTracker) TrySplit(fraction float64) (primary, residual any, err error) {
+	if fraction < 0 || fraction > 1 {
+		return nil, nil, errors.New("fraction must be between 0 and 1")
+	}
+
+	// Keep the whole restriction when the caller asks for all remaining work (fraction 1) or when
+	// the tracker is done. IsDone is true both when the full count has been claimed and when the
+	// tracker was stopped because the cursor was exhausted; in either case there is no work left
+	// to hand out, so no residual is produced and MongoDB is not queried.
+	if fraction == 1 || rt.IsDone() {
+		return rt.rest, nil, nil
+	}
+
+	done, remaining := rt.cutRestriction()
+
+	if fraction == 0 {
+		// Shrinking the restriction to the completed part makes its count equal to the claimed
+		// count, so IsDone reports true from now on.
+		rt.rest = done
+		return rt.rest, remaining, nil
+	}
+
+	// TrySplit receives no context from its caller, so a background context is used for the
+	// query that computes the split point.
+	ctx := context.Background()
+
+	primaryRem, resid, err := remaining.FractionSplits(ctx, rt.collection, fraction)
+	if err != nil {
+		return nil, nil, err
+	}
+
+	// An empty residual means the primary keeps all remaining work: nothing changes.
+	if resid.Count == 0 {
+		return rt.rest, nil, nil
+	}
+
+	// An empty primary remainder means all remaining work is handed off, as for fraction 0.
+	if primaryRem.Count == 0 {
+		rt.rest = done
+		return rt.rest, remaining, nil
+	}
+
+	// Move the upper bound of the tracked restriction down to the split point and remove the
+	// residual's documents from its count.
+	rt.rest.IDRange.Max = primaryRem.IDRange.Max
+	rt.rest.IDRange.MaxInclusive = primaryRem.IDRange.MaxInclusive
+	rt.rest.Count -= resid.Count
+
+	return rt.rest, resid, nil
+}
+
+// cutRestriction returns two restrictions: done represents the amount of work from the underlying
+// restriction that has already been completed, and remaining represents the amount that remains to
+// be processed. Does not modify the underlying restriction.
+func (rt *idRangeTracker) cutRestriction() (done idRangeRestriction, remaining idRangeRestriction) {
+	outer := rt.rest.IDRange
+
+	// The boundary between done and remaining is the last claimed ID. That ID belongs to the done
+	// part (inclusive max) and is excluded from the remaining part (exclusive min).
+	boundary := rt.claimedID
+	hasClaimed := boundary != nil
+	remainingMinInclusive := false
+
+	if !hasClaimed {
+		// Nothing has been claimed: the done part is empty and ends (exclusively) at the outer
+		// minimum, while the remaining part starts exactly like the outer range.
+		boundary = outer.Min
+		remainingMinInclusive = outer.MinInclusive
+	}
+
+	done = idRangeRestriction{
+		IDRange: idRange{
+			Min:          outer.Min,
+			MinInclusive: outer.MinInclusive,
+			Max:          boundary,
+			MaxInclusive: hasClaimed,
+		},
+		CustomFilter: rt.rest.CustomFilter,
+		Count:        rt.claimed,
+	}
+
+	remaining = idRangeRestriction{
+		IDRange: idRange{
+			Min:          boundary,
+			MinInclusive: remainingMinInclusive,
+			Max:          outer.Max,
+			MaxInclusive: outer.MaxInclusive,
+		},
+		CustomFilter: rt.rest.CustomFilter,
+		Count:        rt.rest.Count - rt.claimed,
+	}
+
+	return done, remaining
+}
+
+// GetProgress reports progress as document counts: done is the number of claimed documents and
+// remaining is the restriction's count minus the claimed documents.
+func (rt *idRangeTracker) GetProgress() (done float64, remaining float64) {
+	return float64(rt.claimed), float64(rt.rest.Count - rt.claimed)
+}
+
+// IsDone reports whether the tracker has been stopped or has claimed as many documents as its
+// restriction counts.
+func (rt *idRangeTracker) IsDone() bool {
+	if rt.stopped {
+		return true
+	}
+
+	return rt.claimed == rt.rest.Count
+}
+
+// GetRestriction returns the restriction the tracker is tracking, by value.
+func (rt *idRangeTracker) GetRestriction() any {
+	return rt.rest
+}
+
+// IsBounded always returns true: the tracked restriction represents a finite amount of work.
+func (*idRangeTracker) IsBounded() bool {
+	return true
+}
diff --git a/sdks/go/pkg/beam/io/mongodbio/id_range_tracker_test.go b/sdks/go/pkg/beam/io/mongodbio/id_range_tracker_test.go
new file mode 100644
index 00000000000..a4484ee1007
--- /dev/null
+++ b/sdks/go/pkg/beam/io/mongodbio/id_range_tracker_test.go
@@ -0,0 +1,461 @@
+// Licensed to the Apache Software Foundation (ASF) under one or more
+// contributor license agreements.
See the NOTICE file distributed with +// this work for additional information regarding copyright ownership. +// The ASF licenses this file to You under the Apache License, Version 2.0 +// (the "License"); you may not use this file except in compliance with +// the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package mongodbio + +import ( + "testing" + + "github.com/google/go-cmp/cmp" + "go.mongodb.org/mongo-driver/bson" +) + +func Test_idRangeTracker_TryClaim(t *testing.T) { + tests := []struct { + name string + tracker *idRangeTracker + pos any + wantOk bool + wantClaimed int64 + wantClaimedID any + wantDone bool + wantErr bool + }{ + { + name: "Return true when claimed count < total count - 1", + tracker: &idRangeTracker{ + rest: idRangeRestriction{ + Count: 10, + }, + claimed: 5, + claimedID: 123, + }, + pos: cursorResult{nextID: 124}, + wantOk: true, + wantClaimed: 6, + wantClaimedID: 124, + wantDone: false, + }, + { + name: "Return true and set to done when claimed count == total count - 1", + tracker: &idRangeTracker{ + rest: idRangeRestriction{ + Count: 10, + }, + claimed: 9, + claimedID: 123, + }, + pos: cursorResult{nextID: 124}, + wantOk: true, + wantClaimed: 10, + wantClaimedID: 124, + wantDone: true, + }, + { + name: "Return false and set to done when cursor is exhausted", + tracker: &idRangeTracker{ + rest: idRangeRestriction{ + Count: 10, + }, + claimed: 5, + claimedID: 123, + }, + pos: cursorResult{nextID: 124, isExhausted: true}, + wantOk: false, + wantClaimed: 5, + wantClaimedID: 123, + wantDone: true, + }, + { + name: "Return false when claimed count == total 
count", + tracker: &idRangeTracker{ + rest: idRangeRestriction{ + Count: 10, + }, + claimed: 10, + claimedID: 123, + }, + pos: cursorResult{nextID: 124}, + wantOk: false, + wantClaimed: 10, + wantClaimedID: 123, + wantDone: true, + }, + { + name: "Return false when tracker is stopped", + tracker: &idRangeTracker{ + rest: idRangeRestriction{ + Count: 10, + }, + claimed: 10, + claimedID: 123, + stopped: true, + }, + pos: cursorResult{nextID: 124}, + wantOk: false, + wantClaimed: 10, + wantClaimedID: 123, + wantDone: true, + }, + { + name: "Return false and set error when pos is of invalid type", + tracker: &idRangeTracker{ + rest: idRangeRestriction{ + Count: 10, + }, + claimed: 5, + claimedID: 123, + }, + pos: "invalid", + wantOk: false, + wantClaimed: 5, + wantClaimedID: 123, + wantDone: false, + wantErr: true, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + if gotOk := tt.tracker.TryClaim(tt.pos); gotOk != tt.wantOk { + t.Errorf("TryClaim() = %v, want %v", gotOk, tt.wantOk) + } + if gotClaimed := tt.tracker.claimed; gotClaimed != tt.wantClaimed { + t.Errorf("claimed = %v, want %v", gotClaimed, tt.wantClaimed) + } + if gotClaimedID := tt.tracker.claimedID; !cmp.Equal(gotClaimedID, tt.wantClaimedID) { + t.Errorf("claimedID = %v, want %v", gotClaimedID, tt.wantClaimedID) + } + if gotDone := tt.tracker.IsDone(); gotDone != tt.wantDone { + t.Errorf("IsDone() = %v, want %v", gotDone, tt.wantDone) + } + if gotErr := tt.tracker.GetError(); (gotErr != nil) != tt.wantErr { + t.Errorf("GetError() error = %v, wantErr %v", gotErr, tt.wantErr) + } + }) + } +} + +func Test_idRangeTracker_TrySplit(t *testing.T) { + tests := []struct { + name string + tracker *idRangeTracker + fraction float64 + wantPrimary any + wantResidual any + wantErr bool + wantDone bool + }{ + { + name: "Primary contains no more work and residual contains all remaining work when fraction is 0", + tracker: &idRangeTracker{ + rest: idRangeRestriction{ + IDRange: idRange{ + Min: 
0, + MinInclusive: true, + Max: 100, + MaxInclusive: false, + }, + CustomFilter: bson.M{"key": "val"}, + Count: 100, + }, + claimed: 70, + claimedID: 69, + }, + fraction: 0, + wantPrimary: idRangeRestriction{ + IDRange: idRange{ + Min: 0, + MinInclusive: true, + Max: 69, + MaxInclusive: true, + }, + CustomFilter: bson.M{"key": "val"}, + Count: 70, + }, + wantResidual: idRangeRestriction{ + IDRange: idRange{ + Min: 69, + MinInclusive: false, + Max: 100, + MaxInclusive: false, + }, + CustomFilter: bson.M{"key": "val"}, + Count: 30, + }, + wantDone: true, + }, + { + name: "Primary contains all original work and residual is nil when fraction is 1", + tracker: &idRangeTracker{ + rest: idRangeRestriction{ + IDRange: idRange{ + Min: 0, + MinInclusive: true, + Max: 100, + MaxInclusive: false, + }, + CustomFilter: bson.M{"key": "val"}, + Count: 100, + }, + claimed: 70, + claimedID: 69, + }, + fraction: 1, + wantPrimary: idRangeRestriction{ + IDRange: idRange{ + Min: 0, + MinInclusive: true, + Max: 100, + MaxInclusive: false, + }, + CustomFilter: bson.M{"key": "val"}, + Count: 100, + }, + wantResidual: nil, + wantDone: false, + }, + { + name: "Primary contains all original work and residual is nil when the total count has been claimed", + tracker: &idRangeTracker{ + rest: idRangeRestriction{ + IDRange: idRange{ + Min: 0, + MinInclusive: true, + Max: 100, + MaxInclusive: false, + }, + CustomFilter: bson.M{"key": "val"}, + Count: 100, + }, + claimed: 100, + claimedID: 99, + }, + fraction: 1, + wantPrimary: idRangeRestriction{ + IDRange: idRange{ + Min: 0, + MinInclusive: true, + Max: 100, + MaxInclusive: false, + }, + CustomFilter: bson.M{"key": "val"}, + Count: 100, + }, + wantResidual: nil, + wantDone: true, + }, + { + name: "Error - fraction is less than 0", + tracker: &idRangeTracker{ + rest: idRangeRestriction{Count: 100}, + }, + fraction: -0.1, + wantErr: true, + }, + { + name: "Error - fraction is greater than 1", + tracker: &idRangeTracker{ + rest: 
idRangeRestriction{Count: 100}, + }, + fraction: 1.1, + wantErr: true, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + gotPrimary, gotResidual, err := tt.tracker.TrySplit(tt.fraction) + if (err != nil) != tt.wantErr { + t.Fatalf("TrySplit() error = %v, wantErr %v", err, tt.wantErr) + } + if diff := cmp.Diff(gotPrimary, tt.wantPrimary); diff != "" { + t.Errorf("TrySplit() gotPrimary mismatch (-want +got):\n%s", diff) + } + if diff := cmp.Diff(gotResidual, tt.wantResidual); diff != "" { + t.Errorf("TrySplit() gotResidual mismatch (-want +got):\n%s", diff) + } + if tt.tracker.IsDone() != tt.wantDone { + t.Errorf("IsDone() = %v, want %v", tt.tracker.IsDone(), tt.wantDone) + } + }) + } +} + +func Test_idRangeTracker_cutRestriction(t *testing.T) { + tests := []struct { + name string + tracker *idRangeTracker + wantDone idRangeRestriction + wantRemaining idRangeRestriction + }{ + { + name: "The tracker's claimedID is used as the max (inclusive) in the done restriction " + + "and as the min (exclusive) in the remaining restriction when claimedID is not nil", + tracker: &idRangeTracker{ + rest: idRangeRestriction{ + IDRange: idRange{ + Min: 0, + MinInclusive: true, + Max: 100, + MaxInclusive: false, + }, + CustomFilter: bson.M{"key": "val"}, + Count: 100, + }, + claimed: 70, + claimedID: 69, + }, + wantDone: idRangeRestriction{ + IDRange: idRange{ + Min: 0, + MinInclusive: true, + Max: 69, + MaxInclusive: true, + }, + CustomFilter: bson.M{"key": "val"}, + Count: 70, + }, + wantRemaining: idRangeRestriction{ + IDRange: idRange{ + Min: 69, + MinInclusive: false, + Max: 100, + MaxInclusive: false, + }, + CustomFilter: bson.M{"key": "val"}, + Count: 30, + }, + }, + { + name: "The tracker's restriction's min ID is used as the max (exclusive) in the done restriction " + + "and as the min in the remaining restriction when claimedID is nil", + tracker: &idRangeTracker{ + rest: idRangeRestriction{ + IDRange: idRange{ + Min: 0, + MinInclusive: false, + 
Max: 100, + MaxInclusive: false, + }, + CustomFilter: bson.M{"key": "val"}, + Count: 100, + }, + claimed: 0, + claimedID: nil, + }, + wantDone: idRangeRestriction{ + IDRange: idRange{ + Min: 0, + MinInclusive: false, + Max: 0, + MaxInclusive: false, + }, + CustomFilter: bson.M{"key": "val"}, + Count: 0, + }, + wantRemaining: idRangeRestriction{ + IDRange: idRange{ + Min: 0, + MinInclusive: false, + Max: 100, + MaxInclusive: false, + }, + CustomFilter: bson.M{"key": "val"}, + Count: 100, + }, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + gotDone, gotRemaining := tt.tracker.cutRestriction() + if diff := cmp.Diff(gotDone, tt.wantDone); diff != "" { + t.Errorf("cutRestriction() gotDone mismatch (-want +got):\n%s", diff) + } + if diff := cmp.Diff(gotRemaining, tt.wantRemaining); diff != "" { + t.Errorf("cutRestriction() gotRemaining mismatch (-want +got):\n%s", diff) + } + }) + } +} + +func Test_idRangeTracker_GetProgress(t *testing.T) { + tracker := &idRangeTracker{ + rest: idRangeRestriction{ + Count: 100, + }, + claimed: 30, + } + wantDone := float64(30) + wantRemaining := float64(70) + + t.Run( + "Done is represented by claimed count, and remaining by total count - claimed count", + func(t *testing.T) { + gotDone, gotRemaining := tracker.GetProgress() + if gotDone != wantDone { + t.Errorf("GetProgress() gotDone = %v, want %v", gotDone, wantDone) + } + if gotRemaining != wantRemaining { + t.Errorf("GetProgress() gotRemaining = %v, want %v", gotRemaining, wantRemaining) + } + }, + ) +} + +func Test_idRangeTracker_IsDone(t *testing.T) { + tests := []struct { + name string + tracker *idRangeTracker + want bool + }{ + { + name: "True when the tracker's claimed count is equal to the total count", + tracker: &idRangeTracker{ + rest: idRangeRestriction{ + Count: 100, + }, + claimed: 100, + }, + want: true, + }, + { + name: "True when the tracker is stopped", + tracker: &idRangeTracker{ + rest: idRangeRestriction{ + Count: 100, + }, + claimed: 
95, + stopped: true, + }, + want: true, + }, + { + name: "False when the tracker is not stopped and its claimed count is less than the total count", + tracker: &idRangeTracker{ + rest: idRangeRestriction{ + Count: 100, + }, + claimed: 95, + }, + want: false, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + if got := tt.tracker.IsDone(); got != tt.want { + t.Errorf("IsDone() = %v, want %v", got, tt.want) + } + }) + } +} diff --git a/sdks/go/pkg/beam/io/mongodbio/read.go b/sdks/go/pkg/beam/io/mongodbio/read.go index b12a6d7738b..59d8cf6aef9 100644 --- a/sdks/go/pkg/beam/io/mongodbio/read.go +++ b/sdks/go/pkg/beam/io/mongodbio/read.go @@ -17,35 +17,28 @@ package mongodbio import ( "context" + "errors" "fmt" - "math" "reflect" "github.com/apache/beam/sdks/v2/go/pkg/beam" + "github.com/apache/beam/sdks/v2/go/pkg/beam/core/sdf" "github.com/apache/beam/sdks/v2/go/pkg/beam/log" "github.com/apache/beam/sdks/v2/go/pkg/beam/register" "github.com/apache/beam/sdks/v2/go/pkg/beam/util/structx" "go.mongodb.org/mongo-driver/bson" "go.mongodb.org/mongo-driver/mongo" "go.mongodb.org/mongo-driver/mongo/options" - "go.mongodb.org/mongo-driver/mongo/readpref" ) const ( defaultReadBundleSize = 64 * 1024 * 1024 - - minSplitVectorChunkSize = 1024 * 1024 - maxSplitVectorChunkSize = 1024 * 1024 * 1024 - - maxBucketCount = math.MaxInt32 ) func init() { - register.DoFn3x1[context.Context, []byte, func(bson.M), error](&bucketAutoFn{}) - register.DoFn3x1[context.Context, []byte, func(bson.M), error](&splitVectorFn{}) - register.Emitter1[bson.M]() - - register.DoFn3x1[context.Context, bson.M, func(beam.Y), error](&readFn{}) + register.DoFn4x1[context.Context, *sdf.LockRTracker, []byte, func(beam.Y), error]( + &readFn{}, + ) register.Emitter1[beam.Y]() } @@ -92,338 +85,169 @@ func Read( imp := beam.Impulse(s) - var bundled beam.PCollection - - if option.BucketAuto { - bundled = beam.ParDo(s, newBucketAutoFn(uri, database, collection, option), imp) - } else { - bundled 
= beam.ParDo(s, newSplitVectorFn(uri, database, collection, option), imp) - } - return beam.ParDo( s, newReadFn(uri, database, collection, t, option), - bundled, + imp, beam.TypeDefinition{Var: beam.YType, T: t}, ) } -type bucketAutoFn struct { +type readFn struct { mongoDBFn + BucketAuto bool BundleSize int64 + Filter []byte + Type beam.EncodedType + filter bson.M + projection bson.D } -func newBucketAutoFn( +func newReadFn( uri string, database string, collection string, + t reflect.Type, option *ReadOption, -) *bucketAutoFn { - return &bucketAutoFn{ +) *readFn { + filter, err := encodeBSON[bson.M](option.Filter) + if err != nil { + panic(fmt.Sprintf("mongodbio.newReadFn: %v", err)) + } + + return &readFn{ mongoDBFn: mongoDBFn{ URI: uri, Database: database, Collection: collection, }, + BucketAuto: option.BucketAuto, BundleSize: option.BundleSize, + Filter: filter, + Type: beam.EncodedType{T: t}, } } -func (fn *bucketAutoFn) ProcessElement( - ctx context.Context, - _ []byte, - emit func(bson.M), -) error { - collectionSize, err := fn.getCollectionSize(ctx) - if err != nil { +func (fn *readFn) Setup(ctx context.Context) error { + var err error + if err = fn.mongoDBFn.Setup(ctx); err != nil { return err } - if collectionSize == 0 { - return nil - } - - bucketCount := calculateBucketCount(collectionSize, fn.BundleSize) - - buckets, err := fn.getBuckets(ctx, bucketCount) + fn.filter, err = decodeBSON[bson.M](fn.Filter) if err != nil { return err } - idFilters := idFiltersFromBuckets(buckets) - - for _, filter := range idFilters { - emit(filter) - } + fn.projection = inferProjection(fn.Type.T, bsonTag) return nil } -type collStats struct { - Size int64 `bson:"size"` -} - -func (fn *bucketAutoFn) getCollectionSize(ctx context.Context) (int64, error) { - cmd := bson.M{"collStats": fn.Collection} - opts := options.RunCmd().SetReadPreference(readpref.Primary()) - - var stats collStats - if err := fn.collection.Database().RunCommand(ctx, cmd, opts).Decode(&stats); err != 
nil { - return 0, fmt.Errorf("error executing collStats command: %w", err) - } - - return stats.Size, nil -} - -func calculateBucketCount(collectionSize int64, bundleSize int64) int32 { - if bundleSize < 0 { - panic("monogdbio.calculateBucketCount: bundle size must be greater than 0") +func inferProjection(t reflect.Type, tagKey string) bson.D { + names := structx.InferFieldNames(t, tagKey) + if len(names) == 0 { + panic("mongodbio.inferProjection: no names to infer projection from") } - count := collectionSize / bundleSize - if collectionSize%bundleSize != 0 { - count++ - } + projection := make(bson.D, len(names)) - if count > int64(maxBucketCount) { - count = maxBucketCount + for i, name := range names { + projection[i] = bson.E{Key: name, Value: 1} } - return int32(count) -} - -type bucket struct { - ID minMax `bson:"_id"` -} - -type minMax struct { - Min any `bson:"min"` - Max any `bson:"max"` + return projection } -func (fn *bucketAutoFn) getBuckets(ctx context.Context, count int32) ([]bucket, error) { - pipeline := mongo.Pipeline{bson.D{{ - Key: "$bucketAuto", - Value: bson.M{ - "groupBy": "$_id", - "buckets": count, - }, - }}} - - opts := options.Aggregate().SetAllowDiskUse(true) - - cursor, err := fn.collection.Aggregate(ctx, pipeline, opts) - if err != nil { - return nil, fmt.Errorf("error executing bucketAuto aggregation: %w", err) +func (fn *readFn) CreateInitialRestriction(_ []byte) idRangeRestriction { + ctx := context.Background() + if err := fn.Setup(ctx); err != nil { + panic(err) } - var buckets []bucket - if err = cursor.All(ctx, &buckets); err != nil { - return nil, fmt.Errorf("error decoding buckets: %w", err) - } - - return buckets, nil -} - -func idFiltersFromBuckets(buckets []bucket) []bson.M { - idFilters := make([]bson.M, len(buckets)) - - for i := 0; i < len(buckets); i++ { - filter := bson.M{} - - if i != 0 { - filter["$gt"] = buckets[i].ID.Min - } - - if i != len(buckets)-1 { - filter["$lte"] = buckets[i].ID.Max + outerRange, err := 
findOuterIDRange(ctx, fn.collection, fn.filter) + if err != nil { + if errors.Is(err, mongo.ErrNoDocuments) { + log.Infof( + ctx, + "No documents in collection %s.%s match the provided filter", + fn.Database, + fn.Collection, + ) + return idRangeRestriction{} } - if len(filter) == 0 { - idFilters[i] = filter - } else { - idFilters[i] = bson.M{"_id": filter} - } + panic(err) } - return idFilters -} - -type splitVectorFn struct { - mongoDBFn - BundleSize int64 -} - -func newSplitVectorFn( - uri string, - database string, - collection string, - option *ReadOption, -) *splitVectorFn { - return &splitVectorFn{ - mongoDBFn: mongoDBFn{ - URI: uri, - Database: database, - Collection: collection, - }, - BundleSize: option.BundleSize, - } + return newIDRangeRestriction( + ctx, + fn.collection, + outerRange, + fn.filter, + ) } -func (fn *splitVectorFn) ProcessElement( +func findOuterIDRange( ctx context.Context, - _ []byte, - emit func(bson.M), -) error { - chunkSize := getChunkSize(fn.BundleSize) - - splitKeys, err := fn.getSplitKeys(ctx, chunkSize) + collection *mongo.Collection, + filter bson.M, +) (idRange, error) { + minID, err := findID(ctx, collection, filter, 1, 0) if err != nil { - return err + return idRange{}, err } - idFilters := idFiltersFromSplits(splitKeys) - - for _, filter := range idFilters { - emit(filter) + maxID, err := findID(ctx, collection, filter, -1, 0) + if err != nil { + return idRange{}, err } - return nil -} - -func getChunkSize(bundleSize int64) int64 { - var chunkSize int64 - - if bundleSize < minSplitVectorChunkSize { - chunkSize = minSplitVectorChunkSize - } else if bundleSize > maxSplitVectorChunkSize { - chunkSize = maxSplitVectorChunkSize - } else { - chunkSize = bundleSize + outerRange := idRange{ + Min: minID, + MinInclusive: true, + Max: maxID, + MaxInclusive: true, } - return chunkSize -} - -type splitVector struct { - SplitKeys []splitKey `bson:"splitKeys"` -} - -type splitKey struct { - ID any `bson:"_id"` + return outerRange, nil } 
-func (fn *splitVectorFn) getSplitKeys(ctx context.Context, chunkSize int64) ([]splitKey, error) { - cmd := bson.D{ - {Key: "splitVector", Value: fmt.Sprintf("%s.%s", fn.Database, fn.Collection)}, - {Key: "keyPattern", Value: bson.D{{Key: "_id", Value: 1}}}, - {Key: "maxChunkSizeBytes", Value: chunkSize}, +func (fn *readFn) SplitRestriction(_ []byte, rest idRangeRestriction) []idRangeRestriction { + if rest.Count == 0 { + return []idRangeRestriction{rest} } - opts := options.RunCmd().SetReadPreference(readpref.Primary()) - - var vector splitVector - if err := fn.collection.Database().RunCommand(ctx, cmd, opts).Decode(&vector); err != nil { - return nil, fmt.Errorf("error executing splitVector command: %w", err) + ctx := context.Background() + if err := fn.Setup(ctx); err != nil { + panic(err) } - return vector.SplitKeys, nil -} - -func idFiltersFromSplits(splitKeys []splitKey) []bson.M { - idFilters := make([]bson.M, len(splitKeys)+1) - - for i := 0; i < len(splitKeys)+1; i++ { - filter := bson.M{} - - if i > 0 { - filter["$gt"] = splitKeys[i-1].ID - } - - if i < len(splitKeys) { - filter["$lte"] = splitKeys[i].ID - } - - if len(filter) == 0 { - idFilters[i] = filter - } else { - idFilters[i] = bson.M{"_id": filter} - } - } - - return idFilters -} - -type readFn struct { - mongoDBFn - Filter []byte - Type beam.EncodedType - projection bson.D - filter bson.M -} - -func newReadFn( - uri string, - database string, - collection string, - t reflect.Type, - option *ReadOption, -) *readFn { - filter, err := encodeBSONMap(option.Filter) + splits, err := rest.SizedSplits(ctx, fn.collection, fn.BundleSize, fn.BucketAuto) if err != nil { - panic(fmt.Sprintf("mongodbio.newReadFn: %v", err)) + panic(err) } - return &readFn{ - mongoDBFn: mongoDBFn{ - URI: uri, - Database: database, - Collection: collection, - }, - Filter: filter, - Type: beam.EncodedType{T: t}, - } + return splits } -func (fn *readFn) Setup(ctx context.Context) error { - if err := fn.mongoDBFn.Setup(ctx); err != 
nil { - return err - } - - filter, err := decodeBSONMap(fn.Filter) - if err != nil { - return err - } - - fn.filter = filter - fn.projection = inferProjection(fn.Type.T, bsonTag) - - return nil +func (fn *readFn) CreateTracker(rest idRangeRestriction) *sdf.LockRTracker { + return sdf.NewLockRTracker(newIDRangeTracker(rest, fn.collection)) } -func inferProjection(t reflect.Type, tagKey string) bson.D { - names := structx.InferFieldNames(t, tagKey) - if len(names) == 0 { - panic("mongodbio.inferProjection: no names to infer projection from") - } - - projection := make(bson.D, len(names)) - - for i, name := range names { - projection[i] = bson.E{Key: name, Value: 1} - } - - return projection +func (fn *readFn) RestrictionSize(_ []byte, rest idRangeRestriction) float64 { + return float64(rest.Count) } func (fn *readFn) ProcessElement( ctx context.Context, - elem bson.M, + rt *sdf.LockRTracker, + _ []byte, emit func(beam.Y), ) (err error) { - mergedFilter := mergeFilters(elem, fn.filter) + rest := rt.GetRestriction().(idRangeRestriction) - cursor, err := fn.findDocuments(ctx, fn.projection, mergedFilter) + cursor, err := fn.getCursor(ctx, rest.Filter()) if err != nil { return err } @@ -442,53 +266,53 @@ func (fn *readFn) ProcessElement( }() for cursor.Next(ctx) { - value, err := decodeDocument(cursor, fn.Type.T) + id, value, err := decodeDocument(cursor, fn.Type.T) if err != nil { return err } - emit(value) - } - - return cursor.Err() -} + result := cursorResult{nextID: id} + if !rt.TryClaim(result) { + return cursor.Err() + } -func mergeFilters(idFilter bson.M, customFilter bson.M) bson.M { - if len(idFilter) == 0 { - return customFilter + emit(value) } - if len(customFilter) == 0 { - return idFilter - } + result := cursorResult{isExhausted: true} + rt.TryClaim(result) - return bson.M{ - "$and": []bson.M{idFilter, customFilter}, - } + return cursor.Err() } -func (fn *readFn) findDocuments( +func (fn *readFn) getCursor( ctx context.Context, - projection bson.D, filter 
bson.M, ) (*mongo.Cursor, error) { - opts := options.Find().SetProjection(projection) + opts := options.Find(). + SetProjection(fn.projection). + SetSort(bson.M{"_id": 1}) cursor, err := fn.collection.Find(ctx, filter, opts) if err != nil { - return nil, fmt.Errorf("error finding documents: %w", err) + return nil, fmt.Errorf("error executing find command: %w", err) } return cursor, nil } -func decodeDocument(cursor *mongo.Cursor, t reflect.Type) (any, error) { +func decodeDocument(cursor *mongo.Cursor, t reflect.Type) (id any, value any, err error) { + var docID documentID + if err := cursor.Decode(&docID); err != nil { + return nil, nil, fmt.Errorf("error decoding document ID: %w", err) + } + out := reflect.New(t).Interface() if err := cursor.Decode(out); err != nil { - return nil, fmt.Errorf("error decoding document: %w", err) + return nil, nil, fmt.Errorf("error decoding document: %w", err) } - value := reflect.ValueOf(out).Elem().Interface() + value = reflect.ValueOf(out).Elem().Interface() - return value, nil + return docID.ID, value, nil } diff --git a/sdks/go/pkg/beam/io/mongodbio/read_test.go b/sdks/go/pkg/beam/io/mongodbio/read_test.go index 5899457d5a8..666960b17b9 100644 --- a/sdks/go/pkg/beam/io/mongodbio/read_test.go +++ b/sdks/go/pkg/beam/io/mongodbio/read_test.go @@ -16,7 +16,6 @@ package mongodbio import ( - "math" "reflect" "testing" @@ -24,250 +23,6 @@ import ( "go.mongodb.org/mongo-driver/bson" ) -func Test_calculateBucketCount(t *testing.T) { - tests := []struct { - name string - collectionSize int64 - bundleSize int64 - want int32 - }{ - { - name: "Return ceiling of collection size / bundle size", - collectionSize: 3 * 1024 * 1024, - bundleSize: 2 * 1024 * 1024, - want: 2, - }, - { - name: "Return max int32 when calculated count is greater than max int32", - collectionSize: 1024 * 1024 * 1024 * 1024, - bundleSize: 1, - want: math.MaxInt32, - }, - } - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - if got := 
calculateBucketCount(tt.collectionSize, tt.bundleSize); got != tt.want { - t.Errorf("calculateBucketCount() = %v, want %v", got, tt.want) - } - }) - } -} - -func Test_calculateBucketCountPanic(t *testing.T) { - t.Run("Panic when bundleSize is not greater than 0", func(t *testing.T) { - defer func() { - if r := recover(); r == nil { - t.Errorf("calculateBucketCount() does not panic") - } - }() - - calculateBucketCount(1024, 0) - }) -} - -func Test_idFiltersFromBuckets(t *testing.T) { - tests := []struct { - name string - buckets []bucket - want []bson.M - }{ - { - name: "Create one $lte filter for start range, one $gt filter for end range, and filters with both " + - "$lte and $gt for ranges in between when there are three or more bucket elements", - buckets: []bucket{ - { - ID: minMax{ - Min: objectIDFromHex(t, "6384e03f24f854c1a8ce5378"), - Max: objectIDFromHex(t, "6384e03f24f854c1a8ce5380"), - }, - }, - { - ID: minMax{ - Min: objectIDFromHex(t, "6384e03f24f854c1a8ce5380"), - Max: objectIDFromHex(t, "6384e03f24f854c1a8ce5382"), - }, - }, - { - ID: minMax{ - Min: objectIDFromHex(t, "6384e03f24f854c1a8ce5382"), - Max: objectIDFromHex(t, "6384e03f24f854c1a8ce5384"), - }, - }, - }, - want: []bson.M{ - { - "_id": bson.M{ - "$lte": objectIDFromHex(t, "6384e03f24f854c1a8ce5380"), - }, - }, - { - "_id": bson.M{ - "$gt": objectIDFromHex(t, "6384e03f24f854c1a8ce5380"), - "$lte": objectIDFromHex(t, "6384e03f24f854c1a8ce5382"), - }, - }, - { - "_id": bson.M{ - "$gt": objectIDFromHex(t, "6384e03f24f854c1a8ce5382"), - }, - }, - }, - }, - { - name: "Create one $lte filter for start range and one $gt filter for end range when there are two " + - "bucket elements", - buckets: []bucket{ - { - ID: minMax{ - Min: objectIDFromHex(t, "6384e03f24f854c1a8ce5378"), - Max: objectIDFromHex(t, "6384e03f24f854c1a8ce5380"), - }, - }, - { - ID: minMax{ - Min: objectIDFromHex(t, "6384e03f24f854c1a8ce5380"), - Max: objectIDFromHex(t, "6384e03f24f854c1a8ce5382"), - }, - }, - }, - want: []bson.M{ - 
{ - "_id": bson.M{ - "$lte": objectIDFromHex(t, "6384e03f24f854c1a8ce5380"), - }, - }, - { - "_id": bson.M{ - "$gt": objectIDFromHex(t, "6384e03f24f854c1a8ce5380"), - }, - }, - }, - }, - { - name: "Create an empty filter when there is one bucket element", - buckets: []bucket{ - { - ID: minMax{ - Min: objectIDFromHex(t, "6384e03f24f854c1a8ce5378"), - Max: objectIDFromHex(t, "6384e03f24f854c1a8ce5380"), - }, - }, - }, - want: []bson.M{{}}, - }, - } - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - if got := idFiltersFromBuckets(tt.buckets); !cmp.Equal(got, tt.want) { - t.Errorf("idFiltersFromBuckets() = %v, want %v", got, tt.want) - } - }) - } -} - -func Test_getChunkSize(t *testing.T) { - tests := []struct { - name string - bundleSize int64 - want int64 - }{ - { - name: "Return 1 MB if bundle size is less than 1 MB", - bundleSize: 1024, - want: 1024 * 1024, - }, - { - name: "Return 1 GB if bundle size is greater than 1 GB", - bundleSize: 2 * 1024 * 1024 * 1024, - want: 1024 * 1024 * 1024, - }, - { - name: "Return bundle size if bundle size is between 1 MB and 1 GB", - bundleSize: 4 * 1024 * 1024, - want: 4 * 1024 * 1024, - }, - } - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - if got := getChunkSize(tt.bundleSize); got != tt.want { - t.Errorf("getChunkSize() = %v, want %v", got, tt.want) - } - }) - } -} - -func Test_idFiltersFromSplits(t *testing.T) { - tests := []struct { - name string - splitKeys []splitKey - want []bson.M - }{ - { - name: "Create one $lte filter for start range, one $gt filter for end range, and filters with both " + - "$lte and $gt for ranges in between when there are two or more splitKey elements", - splitKeys: []splitKey{ - { - ID: objectIDFromHex(t, "6384e03f24f854c1a8ce5380"), - }, - { - ID: objectIDFromHex(t, "6384e03f24f854c1a8ce5382"), - }, - }, - want: []bson.M{ - { - "_id": bson.M{ - "$lte": objectIDFromHex(t, "6384e03f24f854c1a8ce5380"), - }, - }, - { - "_id": bson.M{ - "$gt": objectIDFromHex(t, 
"6384e03f24f854c1a8ce5380"), - "$lte": objectIDFromHex(t, "6384e03f24f854c1a8ce5382"), - }, - }, - { - "_id": bson.M{ - "$gt": objectIDFromHex(t, "6384e03f24f854c1a8ce5382"), - }, - }, - }, - }, - { - name: "Create one $lte filter for start range and one $gt filter for end range when there is one " + - "splitKey element", - splitKeys: []splitKey{ - { - ID: objectIDFromHex(t, "6384e03f24f854c1a8ce5380"), - }, - }, - want: []bson.M{ - { - "_id": bson.M{ - "$lte": objectIDFromHex(t, "6384e03f24f854c1a8ce5380"), - }, - }, - { - "_id": bson.M{ - "$gt": objectIDFromHex(t, "6384e03f24f854c1a8ce5380"), - }, - }, - }, - }, - { - name: "Create an empty filter when there are no splitKey elements", - splitKeys: nil, - want: []bson.M{{}}, - }, - } - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - if got := idFiltersFromSplits(tt.splitKeys); !cmp.Equal(got, tt.want) { - t.Errorf("idFiltersFromSplits() = %v, want %v", got, tt.want) - } - }) - } -} - func Test_inferProjection(t *testing.T) { type doc struct { Field1 string `bson:"field1"` @@ -313,81 +68,3 @@ func Test_inferProjectionPanic(t *testing.T) { inferProjection(reflect.TypeOf(doc{}), "bson") }) } - -func Test_mergeFilters(t *testing.T) { - tests := []struct { - name string - idFilter bson.M - filter bson.M - want bson.M - }{ - { - name: "Returned merged ID filter and custom filter in an $and filter", - idFilter: bson.M{ - "_id": bson.M{ - "$gte": 10, - }, - }, - filter: bson.M{ - "key": bson.M{ - "$ne": "value", - }, - }, - want: bson.M{ - "$and": []bson.M{ - { - "_id": bson.M{ - "$gte": 10, - }, - }, - { - "key": bson.M{ - "$ne": "value", - }, - }, - }, - }, - }, - { - name: "Return only ID filter when custom filter is empty", - idFilter: bson.M{ - "_id": bson.M{ - "$gte": 10, - }, - }, - filter: bson.M{}, - want: bson.M{ - "_id": bson.M{ - "$gte": 10, - }, - }, - }, - { - name: "Return only custom filter when ID filter is empty", - idFilter: bson.M{}, - filter: bson.M{ - "key": bson.M{ - "$ne": 
"value", - }, - }, - want: bson.M{ - "key": bson.M{ - "$ne": "value", - }, - }, - }, - { - name: "Return empty filter when both ID filter and custom filter are empty", - idFilter: bson.M{}, - filter: bson.M{}, - want: bson.M{}, - }, - } - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - if got := mergeFilters(tt.idFilter, tt.filter); !cmp.Equal(got, tt.want) { - t.Errorf("mergeFilters() = %v, want %v", got, tt.want) - } - }) - } -}