This is an automated email from the ASF dual-hosted git repository.

lostluck pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/beam.git


The following commit(s) were added to refs/heads/master by this push:
     new 2b96716ef5f Implement mongodbio.Read with an SDF (#25160)
2b96716ef5f is described below

commit 2b96716ef5f1e575bb53cf3d23843d42faa45ce3
Author: Johanna Öjeling <51084516+johannaojel...@users.noreply.github.com>
AuthorDate: Sat Jan 28 01:29:30 2023 +0100

    Implement mongodbio.Read with an SDF (#25160)
---
 sdks/go/pkg/beam/io/mongodbio/coder.go             |  25 +-
 sdks/go/pkg/beam/io/mongodbio/coder_test.go        | 135 +++---
 sdks/go/pkg/beam/io/mongodbio/common.go            |  38 +-
 .../pkg/beam/io/mongodbio/id_range_restriction.go  | 206 +++++++++
 .../beam/io/mongodbio/id_range_restriction_test.go | 179 ++++++++
 sdks/go/pkg/beam/io/mongodbio/id_range_split.go    | 248 +++++++++++
 .../pkg/beam/io/mongodbio/id_range_split_test.go   | 275 ++++++++++++
 sdks/go/pkg/beam/io/mongodbio/id_range_tracker.go  | 194 +++++++++
 .../pkg/beam/io/mongodbio/id_range_tracker_test.go | 461 +++++++++++++++++++++
 sdks/go/pkg/beam/io/mongodbio/read.go              | 402 +++++-------------
 sdks/go/pkg/beam/io/mongodbio/read_test.go         | 323 ---------------
 11 files changed, 1774 insertions(+), 712 deletions(-)

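For readers skimming the change: the exported mongodbio.Read entry point presumably keeps its shape, and the rework replaces the previous static pre-splitting with a splittable DoFn (SDF) over _id ranges. The snippet below is only a hedged usage sketch — the Read signature, option behavior, and all names (URI, database, collection, element type) are assumptions for illustration, not taken from this diff.

    package main

    import (
        "context"
        "reflect"

        "github.com/apache/beam/sdks/v2/go/pkg/beam"
        "github.com/apache/beam/sdks/v2/go/pkg/beam/io/mongodbio"
        "github.com/apache/beam/sdks/v2/go/pkg/beam/x/beamx"
    )

    // exampleDoc is a hypothetical element type; field names and tags are illustrative only.
    type exampleDoc struct {
        ID   string `bson:"_id"`
        Name string `bson:"name"`
    }

    func main() {
        beam.Init()
        p, s := beam.NewPipelineWithRoot()

        // Assumed signature: Read(scope, uri, database, collection, elementType, opts...).
        // The SDF-based implementation splits the collection's _id space into ranges
        // and can split those ranges further while the pipeline runs.
        docs := mongodbio.Read(
            s,
            "mongodb://localhost:27017", // hypothetical connection URI
            "example_db",                // hypothetical database name
            "example_coll",              // hypothetical collection name
            reflect.TypeOf(exampleDoc{}),
        )
        _ = docs // apply further transforms here

        if err := beamx.Run(context.Background(), p); err != nil {
            panic(err)
        }
    }
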
diff --git a/sdks/go/pkg/beam/io/mongodbio/coder.go b/sdks/go/pkg/beam/io/mongodbio/coder.go
index c140f9a8a25..3100f0fd93d 100644
--- a/sdks/go/pkg/beam/io/mongodbio/coder.go
+++ b/sdks/go/pkg/beam/io/mongodbio/coder.go
@@ -26,9 +26,14 @@ import (
 
 func init() {
        beam.RegisterCoder(
-               reflect.TypeOf((*bson.M)(nil)).Elem(),
-               encodeBSONMap,
-               decodeBSONMap,
+               reflect.TypeOf((*idRangeRestriction)(nil)).Elem(),
+               encodeBSON[idRangeRestriction],
+               decodeBSON[idRangeRestriction],
+       )
+       beam.RegisterCoder(
+               reflect.TypeOf((*idRange)(nil)).Elem(),
+               encodeBSON[idRange],
+               decodeBSON[idRange],
        )
        beam.RegisterCoder(
                reflect.TypeOf((*primitive.ObjectID)(nil)).Elem(),
@@ -37,19 +42,19 @@ func init() {
        )
 }
 
-func encodeBSONMap(m bson.M) ([]byte, error) {
-       bytes, err := bson.Marshal(m)
+func encodeBSON[T any](in T) ([]byte, error) {
+       out, err := bson.Marshal(in)
        if err != nil {
                return nil, fmt.Errorf("error encoding BSON: %w", err)
        }
 
-       return bytes, nil
+       return out, nil
 }
 
-func decodeBSONMap(bytes []byte) (bson.M, error) {
-       var out bson.M
-       if err := bson.Unmarshal(bytes, &out); err != nil {
-               return nil, fmt.Errorf("error decoding BSON: %w", err)
+func decodeBSON[T any](in []byte) (T, error) {
+       var out T
+       if err := bson.Unmarshal(in, &out); err != nil {
+               return out, fmt.Errorf("error decoding BSON: %w", err)
        }
 
        return out, nil
diff --git a/sdks/go/pkg/beam/io/mongodbio/coder_test.go b/sdks/go/pkg/beam/io/mongodbio/coder_test.go
index d5e3bb2974d..98f81c50ef2 100644
--- a/sdks/go/pkg/beam/io/mongodbio/coder_test.go
+++ b/sdks/go/pkg/beam/io/mongodbio/coder_test.go
@@ -23,137 +23,102 @@ import (
        "go.mongodb.org/mongo-driver/bson/primitive"
 )
 
-func Test_encodeBSONMap(t *testing.T) {
+func Test_encodeDecodeBSONMap(t *testing.T) {
        tests := []struct {
-               name    string
-               m       bson.M
-               want    []byte
-               wantErr bool
+               name string
+               val  bson.M
        }{
                {
-                       name:    "Encode bson.M",
-                       m:       bson.M{"key": "val"},
-                       want:    []byte{18, 0, 0, 0, 2, 107, 101, 121, 0, 4, 0, 
0, 0, 118, 97, 108, 0, 0},
-                       wantErr: false,
+                       name: "Encode/decode bson.M",
+                       val:  bson.M{"key": "val"},
                },
                {
-                       name:    "Encode empty bson.M",
-                       m:       bson.M{},
-                       want:    []byte{5, 0, 0, 0, 0},
-                       wantErr: false,
-               },
-               {
-                       name:    "Encode nil bson.M",
-                       m:       bson.M(nil),
-                       want:    []byte{5, 0, 0, 0, 0},
-                       wantErr: false,
-               },
-               {
-                       name:    "Error - invalid bson.M",
-                       m:       bson.M{"key": make(chan int)},
-                       wantErr: true,
+                       name: "Encode/decode empty bson.M",
+                       val:  bson.M{},
                },
        }
        for _, tt := range tests {
                t.Run(tt.name, func(t *testing.T) {
-                       got, err := encodeBSONMap(tt.m)
-                       if (err != nil) != tt.wantErr {
-                               t.Fatalf("encodeBSONMap() error = %v, wantErr 
%v", err, tt.wantErr)
+                       encoded, err := encodeBSON[bson.M](tt.val)
+                       if err != nil {
+                               t.Fatalf("encodeBSON[bson.M]() error = %v", err)
+                       }
+
+                       decoded, err := decodeBSON[bson.M](encoded)
+                       if err != nil {
+                               t.Fatalf("decodeBSON[bson.M]() error = %v", err)
                        }
 
-                       if !cmp.Equal(got, tt.want) {
-                               t.Errorf("encodeBSONMap() got = %v, want %v", 
got, tt.want)
+                       if diff := cmp.Diff(tt.val, decoded); diff != "" {
+                               t.Errorf("encode/decode mismatch (-want 
+got):\n%s", diff)
                        }
                })
        }
 }
 
-func Test_decodeBSONMap(t *testing.T) {
+func Test_encodeDecodeIDRangeRestriction(t *testing.T) {
        tests := []struct {
-               name    string
-               bytes   []byte
-               want    bson.M
-               wantErr bool
+               name string
+               rest idRangeRestriction
        }{
                {
-                       name:    "Decode bson.M",
-                       bytes:   []byte{18, 0, 0, 0, 2, 107, 101, 121, 0, 4, 0, 
0, 0, 118, 97, 108, 0, 0},
-                       want:    bson.M{"key": "val"},
-                       wantErr: false,
+                       name: "Encode/decode idRangeRestriction",
+                       rest: idRangeRestriction{
+                               IDRange: idRange{
+                                       Min:          objectIDFromHex(t, 
"5f1b2c3d4e5f60708090a0b0"),
+                                       MinInclusive: true,
+                                       Max:          objectIDFromHex(t, 
"5f1b2c3d4e5f60708090a0b9"),
+                                       MaxInclusive: true,
+                               },
+                               CustomFilter: bson.M{"key": "val"},
+                               Count:        5,
+                       },
                },
                {
-                       name:    "Decode empty bson.M",
-                       bytes:   []byte{5, 0, 0, 0, 0},
-                       want:    bson.M{},
-                       wantErr: false,
-               },
-               {
-                       name:    "Error - invalid bson.M",
-                       bytes:   []byte{},
-                       wantErr: true,
+                       name: "Encode/decode empty idRangeRestriction",
+                       rest: idRangeRestriction{},
                },
        }
        for _, tt := range tests {
                t.Run(tt.name, func(t *testing.T) {
-                       got, err := decodeBSONMap(tt.bytes)
-                       if (err != nil) != tt.wantErr {
-                               t.Fatalf("decodeBSONMap() error = %v, wantErr 
%v", err, tt.wantErr)
+                       encoded, err := encodeBSON[idRangeRestriction](tt.rest)
+                       if err != nil {
+                               t.Fatalf("encodeBSON[idRangeRestriction]() 
error = %v", err)
+                       }
+
+                       decoded, err := decodeBSON[idRangeRestriction](encoded)
+                       if err != nil {
+                               t.Fatalf("decodeBSON[idRangeRestriction]() 
error = %v", err)
                        }
 
-                       if !cmp.Equal(got, tt.want) {
-                               t.Errorf("decodeBSONMap() got = %v, want %v", 
got, tt.want)
+                       if diff := cmp.Diff(tt.rest, decoded); diff != "" {
+                               t.Errorf("encode/decode mismatch (-want 
+got):\n%s", diff)
                        }
                })
        }
 }
 
-func Test_encodeObjectID(t *testing.T) {
+func Test_encodeDecodeObjectID(t *testing.T) {
        tests := []struct {
                name     string
                objectID primitive.ObjectID
-               want     []byte
        }{
                {
-                       name:     "Encode object ID",
+                       name:     "Encode/decode object ID",
                        objectID: objectIDFromHex(t, 
"5f1b2c3d4e5f60708090a0b0"),
-                       want:     []byte{95, 27, 44, 61, 78, 95, 96, 112, 128, 
144, 160, 176},
                },
                {
-                       name:     "Encode nil object ID",
+                       name:     "Encode/decode nil object ID",
                        objectID: primitive.NilObjectID,
-                       want:     []byte{0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0},
                },
        }
        for _, tt := range tests {
                t.Run(tt.name, func(t *testing.T) {
-                       if got := encodeObjectID(tt.objectID); !cmp.Equal(got, 
tt.want) {
-                               t.Errorf("encodeObjectID() = %v, want %v", got, 
tt.want)
-                       }
-               })
-       }
-}
+                       encoded := encodeObjectID(tt.objectID)
+                       decoded := decodeObjectID(encoded)
 
-func Test_decodeObjectID(t *testing.T) {
-       tests := []struct {
-               name  string
-               bytes []byte
-               want  primitive.ObjectID
-       }{
-               {
-                       name:  "Decode object ID",
-                       bytes: []byte{95, 27, 44, 61, 78, 95, 96, 112, 128, 
144, 160, 176},
-                       want:  objectIDFromHex(t, "5f1b2c3d4e5f60708090a0b0"),
-               },
-               {
-                       name:  "Decode nil object ID",
-                       bytes: []byte{0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0},
-                       want:  primitive.NilObjectID,
-               },
-       }
-       for _, tt := range tests {
-               t.Run(tt.name, func(t *testing.T) {
-                       if got := decodeObjectID(tt.bytes); !cmp.Equal(got, 
tt.want) {
-                               t.Errorf("decodeObjectID() = %v, want %v", got, 
tt.want)
+                       if !cmp.Equal(decoded, tt.objectID) {
+                               t.Errorf("decodeObjectID() = %v, want %v", 
decoded, tt.objectID)
                        }
                })
        }
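
The tests above exercise the new generic coder pair as an encode/decode round trip instead of asserting on raw BSON bytes. Outside the test suite the same pattern looks like the sketch below; the helpers are unexported, so simplified copies are included purely so the snippet stands alone.

    package main

    import (
        "fmt"

        "go.mongodb.org/mongo-driver/bson"
    )

    // Simplified copies of the generic helpers from coder.go.
    func encodeBSON[T any](in T) ([]byte, error) {
        return bson.Marshal(in)
    }

    func decodeBSON[T any](in []byte) (T, error) {
        var out T
        err := bson.Unmarshal(in, &out)
        return out, err
    }

    // payload is a hypothetical BSON-serializable type.
    type payload struct {
        Key string `bson:"key"`
    }

    func main() {
        encoded, err := encodeBSON(payload{Key: "val"})
        if err != nil {
            panic(err)
        }
        decoded, err := decodeBSON[payload](encoded)
        if err != nil {
            panic(err)
        }
        fmt.Println(decoded.Key) // prints "val"
    }
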
diff --git a/sdks/go/pkg/beam/io/mongodbio/common.go b/sdks/go/pkg/beam/io/mongodbio/common.go
index 9d6ffbeaa95..e1d0657a206 100644
--- a/sdks/go/pkg/beam/io/mongodbio/common.go
+++ b/sdks/go/pkg/beam/io/mongodbio/common.go
@@ -20,6 +20,7 @@ import (
        "context"
        "fmt"
 
+       "go.mongodb.org/mongo-driver/bson"
        "go.mongodb.org/mongo-driver/mongo"
        "go.mongodb.org/mongo-driver/mongo/options"
        "go.mongodb.org/mongo-driver/mongo/readpref"
@@ -38,13 +39,16 @@ type mongoDBFn struct {
 }
 
 func (fn *mongoDBFn) Setup(ctx context.Context) error {
-       client, err := newClient(ctx, fn.URI)
-       if err != nil {
-               return err
+       if fn.client == nil {
+               client, err := newClient(ctx, fn.URI)
+               if err != nil {
+                       return err
+               }
+
+               fn.client = client
        }
 
-       fn.client = client
-       fn.collection = client.Database(fn.Database).Collection(fn.Collection)
+       fn.collection = fn.client.Database(fn.Database).Collection(fn.Collection)
 
        return nil
 }
@@ -71,3 +75,27 @@ func (fn *mongoDBFn) Teardown(ctx context.Context) error {
 
        return nil
 }
+
+type documentID struct {
+       ID any `bson:"_id"`
+}
+
+func findID(
+       ctx context.Context,
+       collection *mongo.Collection,
+       filter any,
+       order int,
+       skip int64,
+) (any, error) {
+       opts := options.FindOne().
+               SetProjection(bson.M{"_id": 1}).
+               SetSort(bson.M{"_id": order}).
+               SetSkip(skip)
+
+       var docID documentID
+       if err := collection.FindOne(ctx, filter, opts).Decode(&docID); err != nil {
+               return nil, err
+       }
+
+       return docID.ID, nil
+}
diff --git a/sdks/go/pkg/beam/io/mongodbio/id_range_restriction.go b/sdks/go/pkg/beam/io/mongodbio/id_range_restriction.go
new file mode 100644
index 00000000000..c527631cd66
--- /dev/null
+++ b/sdks/go/pkg/beam/io/mongodbio/id_range_restriction.go
@@ -0,0 +1,206 @@
+// Licensed to the Apache Software Foundation (ASF) under one or more
+// contributor license agreements.  See the NOTICE file distributed with
+// this work for additional information regarding copyright ownership.
+// The ASF licenses this file to You under the Apache License, Version 2.0
+// (the "License"); you may not use this file except in compliance with
+// the License.  You may obtain a copy of the License at
+//
+//    http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package mongodbio
+
+import (
+       "context"
+       "errors"
+       "fmt"
+       "math"
+       "reflect"
+
+       "github.com/apache/beam/sdks/v2/go/pkg/beam"
+       "go.mongodb.org/mongo-driver/bson"
+       "go.mongodb.org/mongo-driver/mongo"
+)
+
+func init() {
+       beam.RegisterType(reflect.TypeOf((*idRangeRestriction)(nil)).Elem())
+       beam.RegisterType(reflect.TypeOf((*idRange)(nil)).Elem())
+}
+
+// idRangeRestriction represents a range of document IDs to read from MongoDB. IDRange holds
+// information about the minimum and maximum IDs. CustomFilter is the custom filter to apply when
+// reading from the collection. Count is the number of documents within the ID range that match the
+// custom filter.
+type idRangeRestriction struct {
+       IDRange      idRange
+       CustomFilter bson.M
+       Count        int64
+}
+
+// newIDRangeRestriction creates a new idRangeRestriction and counts the documents within the ID
+// range that match the custom filter.
+func newIDRangeRestriction(
+       ctx context.Context,
+       collection *mongo.Collection,
+       idRange idRange,
+       filter bson.M,
+) idRangeRestriction {
+       mergedFilter := mergeFilters(idRange.Filter(), filter)
+
+       count, err := collection.CountDocuments(ctx, mergedFilter)
+       if err != nil {
+               panic(err)
+       }
+
+       return idRangeRestriction{
+               IDRange:      idRange,
+               CustomFilter: filter,
+               Count:        count,
+       }
+}
+
+// Filter returns a bson.M filter based on the restriction's ID range and custom filter.
+func (r idRangeRestriction) Filter() bson.M {
+       idFilter := r.IDRange.Filter()
+       return mergeFilters(idFilter, r.CustomFilter)
+}
+
+// mergeFilters merges the ID filter and the custom filter into a single bson.M filter.
+func mergeFilters(idFilter bson.M, customFilter bson.M) bson.M {
+       if len(idFilter) == 0 {
+               return customFilter
+       }
+
+       if len(customFilter) == 0 {
+               return idFilter
+       }
+
+       return bson.M{
+               "$and": []bson.M{idFilter, customFilter},
+       }
+}
+
+// SizedSplits divides the restriction into sub-restrictions based on the desired bundle size in
+// bytes.
+func (r idRangeRestriction) SizedSplits(
+       ctx context.Context,
+       collection *mongo.Collection,
+       bundleSize int64,
+       useBucketAuto bool,
+) ([]idRangeRestriction, error) {
+       var idRanges []idRange
+       var err error
+
+       if useBucketAuto {
+               idRanges, err = bucketAutoSplits(ctx, collection, r.IDRange, bundleSize)
+       } else {
+               idRanges, err = splitVectorSplits(ctx, collection, r.IDRange, bundleSize)
+       }
+
+       if err != nil {
+               return nil, err
+       }
+
+       return restrictionsFromIDRanges(ctx, collection, idRanges, r.CustomFilter), err
+}
+
+// FractionSplits divides the restriction into a lower and higher ID sub-restriction based on the
+// desired fraction of work the lower piece should be responsible for.
+func (r idRangeRestriction) FractionSplits(
+       ctx context.Context,
+       collection *mongo.Collection,
+       fraction float64,
+) (lower, higher idRangeRestriction, err error) {
+       skip := int64(math.Round(float64(r.Count) * fraction))
+
+       splitID, err := findID(ctx, collection, r.Filter(), 1, skip)
+       if err != nil {
+               if errors.Is(err, mongo.ErrNoDocuments) {
+                       return idRangeRestriction{}, idRangeRestriction{}, nil
+               }
+
+               return idRangeRestriction{}, idRangeRestriction{}, fmt.Errorf(
+                       "error finding document ID to split on: %w",
+                       err,
+               )
+       }
+
+       lower = idRangeRestriction{
+               IDRange: idRange{
+                       Min:          r.IDRange.Min,
+                       MinInclusive: r.IDRange.MinInclusive,
+                       Max:          splitID,
+                       MaxInclusive: false,
+               },
+               CustomFilter: r.CustomFilter,
+               Count:        skip,
+       }
+
+       higher = idRangeRestriction{
+               IDRange: idRange{
+                       Min:          splitID,
+                       MinInclusive: true,
+                       Max:          r.IDRange.Max,
+                       MaxInclusive: r.IDRange.MaxInclusive,
+               },
+               CustomFilter: r.CustomFilter,
+               Count:        r.Count - skip,
+       }
+
+       return lower, higher, nil
+}
+
+// restrictionsFromIDRanges creates a slice of new restrictions based on the ID ranges.
+func restrictionsFromIDRanges(
+       ctx context.Context,
+       collection *mongo.Collection,
+       idRanges []idRange,
+       customFilter bson.M,
+) []idRangeRestriction {
+       restrictions := make([]idRangeRestriction, len(idRanges))
+
+       for i := 0; i < len(idRanges); i++ {
+               rest := newIDRangeRestriction(
+                       ctx,
+                       collection,
+                       idRanges[i],
+                       customFilter,
+               )
+               restrictions[i] = rest
+       }
+
+       return restrictions
+}
+
+// idRange represents a range of document IDs in a MongoDB collection. It stores information about
+// the minimum and maximum IDs, and whether they are inclusive or not.
+type idRange struct {
+       Min          any
+       MinInclusive bool
+       Max          any
+       MaxInclusive bool
+}
+
+// Filter creates a bson.M filter representation of the idRange.
+func (i idRange) Filter() bson.M {
+       filter := make(bson.M, 2)
+
+       if i.MinInclusive {
+               filter["$gte"] = i.Min
+       } else {
+               filter["$gt"] = i.Min
+       }
+
+       if i.MaxInclusive {
+               filter["$lte"] = i.Max
+       } else {
+               filter["$lt"] = i.Max
+       }
+
+       return bson.M{"_id": filter}
+}
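
To make the filter construction above concrete: a half-open range [0, 100) combined with a non-empty custom filter yields an $and of the two documents, while an empty custom filter leaves only the _id clause. The snippet below is a standalone illustration of that output shape, not code from the package; the "status" field is hypothetical.

    package main

    import (
        "fmt"

        "go.mongodb.org/mongo-driver/bson"
    )

    func main() {
        // Equivalent of idRange{Min: 0, MinInclusive: true, Max: 100, MaxInclusive: false}.Filter().
        idFilter := bson.M{"_id": bson.M{"$gte": 0, "$lt": 100}}

        // Equivalent of mergeFilters(idFilter, customFilter) when both filters are non-empty.
        customFilter := bson.M{"status": "active"}
        merged := bson.M{"$and": []bson.M{idFilter, customFilter}}

        fmt.Println(merged) // map[$and:[map[_id:map[$gte:0 $lt:100]] map[status:active]]]
    }
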
diff --git a/sdks/go/pkg/beam/io/mongodbio/id_range_restriction_test.go b/sdks/go/pkg/beam/io/mongodbio/id_range_restriction_test.go
new file mode 100644
index 00000000000..0534424ef05
--- /dev/null
+++ b/sdks/go/pkg/beam/io/mongodbio/id_range_restriction_test.go
@@ -0,0 +1,179 @@
+// Licensed to the Apache Software Foundation (ASF) under one or more
+// contributor license agreements.  See the NOTICE file distributed with
+// this work for additional information regarding copyright ownership.
+// The ASF licenses this file to You under the Apache License, Version 2.0
+// (the "License"); you may not use this file except in compliance with
+// the License.  You may obtain a copy of the License at
+//
+//    http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package mongodbio
+
+import (
+       "testing"
+
+       "github.com/google/go-cmp/cmp"
+       "go.mongodb.org/mongo-driver/bson"
+)
+
+func Test_mergeFilters(t *testing.T) {
+       tests := []struct {
+               name     string
+               idFilter bson.M
+               filter   bson.M
+               want     bson.M
+       }{
+               {
+                       name: "Merge ID filter and custom filter in an $and 
filter",
+                       idFilter: bson.M{
+                               "_id": bson.M{
+                                       "$gte": 10,
+                               },
+                       },
+                       filter: bson.M{
+                               "key": bson.M{
+                                       "$ne": "value",
+                               },
+                       },
+                       want: bson.M{
+                               "$and": []bson.M{
+                                       {
+                                               "_id": bson.M{
+                                                       "$gte": 10,
+                                               },
+                                       },
+                                       {
+                                               "key": bson.M{
+                                                       "$ne": "value",
+                                               },
+                                       },
+                               },
+                       },
+               },
+               {
+                       name: "Keep only ID filter when custom filter is empty",
+                       idFilter: bson.M{
+                               "_id": bson.M{
+                                       "$gte": 10,
+                               },
+                       },
+                       filter: bson.M{},
+                       want: bson.M{
+                               "_id": bson.M{
+                                       "$gte": 10,
+                               },
+                       },
+               },
+               {
+                       name:     "Keep only custom filter when ID filter is 
empty",
+                       idFilter: bson.M{},
+                       filter: bson.M{
+                               "key": bson.M{
+                                       "$ne": "value",
+                               },
+                       },
+                       want: bson.M{
+                               "key": bson.M{
+                                       "$ne": "value",
+                               },
+                       },
+               },
+               {
+                       name:     "Empty filter when both ID filter and custom 
filter are empty",
+                       idFilter: bson.M{},
+                       filter:   bson.M{},
+                       want:     bson.M{},
+               },
+       }
+       for _, tt := range tests {
+               t.Run(tt.name, func(t *testing.T) {
+                       got := mergeFilters(tt.idFilter, tt.filter)
+                       if diff := cmp.Diff(got, tt.want); diff != "" {
+                               t.Errorf("mergeFilters() mismatch (-want +got): 
%v", diff)
+                       }
+               })
+       }
+}
+
+func Test_idRange_Filter(t *testing.T) {
+       tests := []struct {
+               name    string
+               idRange idRange
+               want    bson.M
+       }{
+               {
+                       name: "ID filter with $gte when min is inclusive",
+                       idRange: idRange{
+                               Min:          0,
+                               MinInclusive: true,
+                               Max:          10,
+                               MaxInclusive: false,
+                       },
+                       want: bson.M{
+                               "_id": bson.M{
+                                       "$gte": 0,
+                                       "$lt":  10,
+                               },
+                       },
+               },
+               {
+                       name: "ID filter with $gt when min is exclusive",
+                       idRange: idRange{
+                               Min:          0,
+                               MinInclusive: false,
+                               Max:          10,
+                               MaxInclusive: false,
+                       },
+                       want: bson.M{
+                               "_id": bson.M{
+                                       "$gt": 0,
+                                       "$lt": 10,
+                               },
+                       },
+               },
+               {
+                       name: "ID filter with $lte when max is inclusive",
+                       idRange: idRange{
+                               Min:          0,
+                               MinInclusive: true,
+                               Max:          10,
+                               MaxInclusive: true,
+                       },
+                       want: bson.M{
+                               "_id": bson.M{
+                                       "$gte": 0,
+                                       "$lte": 10,
+                               },
+                       },
+               },
+               {
+                       name: "ID filter with $lt when max is exclusive",
+                       idRange: idRange{
+                               Min:          0,
+                               MinInclusive: true,
+                               Max:          10,
+                               MaxInclusive: false,
+                       },
+                       want: bson.M{
+                               "_id": bson.M{
+                                       "$gte": 0,
+                                       "$lt":  10,
+                               },
+                       },
+               },
+       }
+       for _, tt := range tests {
+               t.Run(tt.name, func(t *testing.T) {
+                       got := tt.idRange.Filter()
+                       if diff := cmp.Diff(tt.want, got); diff != "" {
+                               t.Errorf("Filter() mismatch (-want +got): %v", 
diff)
+                       }
+               })
+       }
+}
diff --git a/sdks/go/pkg/beam/io/mongodbio/id_range_split.go b/sdks/go/pkg/beam/io/mongodbio/id_range_split.go
new file mode 100644
index 00000000000..87a6d952866
--- /dev/null
+++ b/sdks/go/pkg/beam/io/mongodbio/id_range_split.go
@@ -0,0 +1,248 @@
+// Licensed to the Apache Software Foundation (ASF) under one or more
+// contributor license agreements.  See the NOTICE file distributed with
+// this work for additional information regarding copyright ownership.
+// The ASF licenses this file to You under the Apache License, Version 2.0
+// (the "License"); you may not use this file except in compliance with
+// the License.  You may obtain a copy of the License at
+//
+//    http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package mongodbio
+
+import (
+       "context"
+       "fmt"
+       "math"
+
+       "go.mongodb.org/mongo-driver/bson"
+       "go.mongodb.org/mongo-driver/mongo"
+       "go.mongodb.org/mongo-driver/mongo/options"
+       "go.mongodb.org/mongo-driver/mongo/readpref"
+)
+
+const (
+       maxBucketCount          = math.MaxInt32
+       minSplitVectorChunkSize = 1024 * 1024
+       maxSplitVectorChunkSize = 1024 * 1024 * 1024
+)
+
+func bucketAutoSplits(
+       ctx context.Context,
+       collection *mongo.Collection,
+       outerRange idRange,
+       bundleSize int64,
+) ([]idRange, error) {
+       collSize, err := getCollectionSize(ctx, collection)
+       if err != nil {
+               return nil, err
+       }
+
+       bucketCount := calculateBucketCount(collSize, bundleSize)
+
+       buckets, err := getBuckets(ctx, collection, outerRange.Filter(), bucketCount)
+       if err != nil {
+               return nil, err
+       }
+
+       return idRangesFromBuckets(buckets, outerRange), nil
+}
+
+func getCollectionSize(ctx context.Context, collection *mongo.Collection) (int64, error) {
+       cmd := bson.M{"collStats": collection.Name()}
+       opts := options.RunCmd().SetReadPreference(readpref.Primary())
+
+       var stats struct {
+               Size int64 `bson:"size"`
+       }
+       if err := collection.Database().RunCommand(ctx, cmd, opts).Decode(&stats); err != nil {
+               return 0, fmt.Errorf("error executing collStats command: %w", err)
+       }
+
+       return stats.Size, nil
+}
+
+func calculateBucketCount(totalSize int64, bundleSize int64) int32 {
+       if bundleSize <= 0 {
+               panic("mongodbio.calculateBucketCount: bundle size must be greater than 0")
+       }
+
+       count := totalSize / bundleSize
+       if totalSize%bundleSize != 0 {
+               count++
+       }
+
+       if count > int64(maxBucketCount) {
+               count = maxBucketCount
+       }
+
+       return int32(count)
+}
+
+type bucket struct {
+       ID minMax `bson:"_id"`
+}
+
+type minMax struct {
+       Min any `bson:"min"`
+       Max any `bson:"max"`
+}
+
+func getBuckets(
+       ctx context.Context,
+       collection *mongo.Collection,
+       filter bson.M,
+       count int32,
+) ([]bucket, error) {
+       pipeline := mongo.Pipeline{
+               bson.D{{
+                       Key:   "$match",
+                       Value: filter,
+               }},
+               bson.D{{
+                       Key: "$bucketAuto",
+                       Value: bson.M{
+                               "groupBy": "$_id",
+                               "buckets": count,
+                       },
+               }},
+       }
+
+       opts := options.Aggregate().SetAllowDiskUse(true)
+
+       cursor, err := collection.Aggregate(ctx, pipeline, opts)
+       if err != nil {
+               return nil, fmt.Errorf("error executing bucketAuto aggregation: %w", err)
+       }
+
+       var buckets []bucket
+       if err := cursor.All(ctx, &buckets); err != nil {
+               return nil, fmt.Errorf("error decoding buckets: %w", err)
+       }
+
+       return buckets, nil
+}
+
+func idRangesFromBuckets(buckets []bucket, outerRange idRange) []idRange {
+       if len(buckets) == 0 {
+               return nil
+       }
+
+       ranges := make([]idRange, len(buckets))
+
+       for i := 0; i < len(buckets); i++ {
+               subRange := idRange{}
+
+               if i == 0 {
+                       subRange.MinInclusive = outerRange.MinInclusive
+                       subRange.Min = outerRange.Min
+               } else {
+                       subRange.Min = buckets[i].ID.Min
+                       subRange.MinInclusive = true
+               }
+
+               if i == len(buckets)-1 {
+                       subRange.Max = outerRange.Max
+                       subRange.MaxInclusive = outerRange.MaxInclusive
+               } else {
+                       subRange.Max = buckets[i].ID.Max
+                       subRange.MaxInclusive = false
+               }
+
+               ranges[i] = subRange
+       }
+
+       return ranges
+}
+
+func splitVectorSplits(
+       ctx context.Context,
+       collection *mongo.Collection,
+       outerRange idRange,
+       bundleSize int64,
+) ([]idRange, error) {
+       chunkSize := getChunkSize(bundleSize)
+
+       splitKeys, err := getSplitKeys(ctx, collection, outerRange, chunkSize)
+       if err != nil {
+               return nil, err
+       }
+
+       return idRangesFromSplits(splitKeys, outerRange), nil
+}
+
+func getChunkSize(bundleSize int64) int64 {
+       var chunkSize int64
+
+       if bundleSize < minSplitVectorChunkSize {
+               chunkSize = minSplitVectorChunkSize
+       } else if bundleSize > maxSplitVectorChunkSize {
+               chunkSize = maxSplitVectorChunkSize
+       } else {
+               chunkSize = bundleSize
+       }
+
+       return chunkSize
+}
+
+func getSplitKeys(
+       ctx context.Context,
+       collection *mongo.Collection,
+       outerRange idRange,
+       maxChunkSizeBytes int64,
+) ([]documentID, error) {
+       database := collection.Database()
+       namespace := fmt.Sprintf("%s.%s", database.Name(), collection.Name())
+
+       cmd := bson.D{
+               {Key: "splitVector", Value: namespace},
+               {Key: "keyPattern", Value: bson.D{{Key: "_id", Value: 1}}},
+               {Key: "min", Value: bson.D{{Key: "_id", Value: outerRange.Min}}},
+               {Key: "max", Value: bson.D{{Key: "_id", Value: outerRange.Max}}},
+               {Key: "maxChunkSizeBytes", Value: maxChunkSizeBytes},
+       }
+
+       opts := options.RunCmd().SetReadPreference(readpref.Primary())
+
+       var result struct {
+               SplitKeys []documentID `bson:"splitKeys"`
+       }
+       if err := database.RunCommand(ctx, cmd, opts).Decode(&result); err != nil {
+               return nil, fmt.Errorf("error executing splitVector command: %w", err)
+       }
+
+       return result.SplitKeys, nil
+}
+
+func idRangesFromSplits(splitKeys []documentID, outerRange idRange) []idRange {
+       subRanges := make([]idRange, len(splitKeys)+1)
+
+       for i := 0; i < len(splitKeys)+1; i++ {
+               subRange := idRange{}
+
+               if i == 0 {
+                       subRange.Min = outerRange.Min
+                       subRange.MinInclusive = outerRange.MinInclusive
+               } else {
+                       subRange.Min = splitKeys[i-1].ID
+                       subRange.MinInclusive = true
+               }
+
+               if i == len(splitKeys) {
+                       subRange.Max = outerRange.Max
+                       subRange.MaxInclusive = outerRange.MaxInclusive
+               } else {
+                       subRange.Max = splitKeys[i].ID
+                       subRange.MaxInclusive = false
+               }
+
+               subRanges[i] = subRange
+       }
+
+       return subRanges
+}
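
The two sizing helpers above decide how finely the ID range is split: bucketAuto splitting aims for ceil(collection size / bundle size) buckets (capped at math.MaxInt32), while splitVector clamps the requested bundle size to the 1 MiB–1 GiB chunk-size window. A worked example of that arithmetic, using only values that appear in the accompanying tests:

    package main

    import "fmt"

    func main() {
        // calculateBucketCount: a 3 MiB collection with a 2 MiB bundle size
        // needs ceil(3/2) = 2 buckets.
        totalSize := int64(3 * 1024 * 1024)
        bundleSize := int64(2 * 1024 * 1024)
        count := totalSize / bundleSize
        if totalSize%bundleSize != 0 {
            count++
        }
        fmt.Println(count) // 2

        // getChunkSize: a 1 KiB bundle size is clamped up to the 1 MiB minimum
        // accepted by the splitVector command.
        requested := int64(1024)
        chunk := requested
        if chunk < 1024*1024 {
            chunk = 1024 * 1024
        }
        fmt.Println(chunk) // 1048576
    }
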
diff --git a/sdks/go/pkg/beam/io/mongodbio/id_range_split_test.go b/sdks/go/pkg/beam/io/mongodbio/id_range_split_test.go
new file mode 100644
index 00000000000..d3e867f6be2
--- /dev/null
+++ b/sdks/go/pkg/beam/io/mongodbio/id_range_split_test.go
@@ -0,0 +1,275 @@
+// Licensed to the Apache Software Foundation (ASF) under one or more
+// contributor license agreements.  See the NOTICE file distributed with
+// this work for additional information regarding copyright ownership.
+// The ASF licenses this file to You under the Apache License, Version 2.0
+// (the "License"); you may not use this file except in compliance with
+// the License.  You may obtain a copy of the License at
+//
+//    http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package mongodbio
+
+import (
+       "math"
+       "testing"
+
+       "github.com/google/go-cmp/cmp"
+)
+
+func Test_calculateBucketCount(t *testing.T) {
+       tests := []struct {
+               name       string
+               totalSize  int64
+               bundleSize int64
+               want       int32
+       }{
+               {
+                       name:       "Return ceiling of total size / bundle 
size",
+                       totalSize:  3 * 1024 * 1024,
+                       bundleSize: 2 * 1024 * 1024,
+                       want:       2,
+               },
+               {
+                       name:       "Return max int32 when calculated count is 
greater than max int32",
+                       totalSize:  1024 * 1024 * 1024 * 1024,
+                       bundleSize: 1,
+                       want:       math.MaxInt32,
+               },
+       }
+       for _, tt := range tests {
+               t.Run(tt.name, func(t *testing.T) {
+                       if got := calculateBucketCount(tt.totalSize, 
tt.bundleSize); got != tt.want {
+                               t.Errorf("calculateBucketCount() = %v, want 
%v", got, tt.want)
+                       }
+               })
+       }
+}
+
+func Test_calculateBucketCountPanic(t *testing.T) {
+       t.Run("Panic when bundleSize is not greater than 0", func(t *testing.T) 
{
+               defer func() {
+                       if r := recover(); r == nil {
+                               t.Errorf("calculateBucketCount() does not 
panic")
+                       }
+               }()
+
+               calculateBucketCount(1024, 0)
+       })
+}
+
+func Test_idRangesFromBuckets(t *testing.T) {
+       tests := []struct {
+               name       string
+               buckets    []bucket
+               outerRange idRange
+               want       []idRange
+       }{
+               {
+                       name: "ID ranges with first element having min ID 
configuration from outer range, and last element " +
+                               "having max ID configuration from outer range",
+                       buckets: []bucket{
+                               {
+                                       ID: minMax{
+                                               Min: 5,
+                                               Max: 100,
+                                       },
+                               },
+                               {
+                                       ID: minMax{
+                                               Min: 100,
+                                               Max: 200,
+                                       },
+                               },
+                               {
+                                       ID: minMax{
+                                               Min: 200,
+                                               Max: 295,
+                                       },
+                               },
+                       },
+                       outerRange: idRange{
+                               Min:          0,
+                               MinInclusive: false,
+                               Max:          300,
+                               MaxInclusive: false,
+                       },
+                       want: []idRange{
+                               {
+                                       Min:          0,
+                                       MinInclusive: false,
+                                       Max:          100,
+                                       MaxInclusive: false,
+                               },
+                               {
+                                       Min:          100,
+                                       MinInclusive: true,
+                                       Max:          200,
+                                       MaxInclusive: false,
+                               },
+                               {
+                                       Min:          200,
+                                       MinInclusive: true,
+                                       Max:          300,
+                                       MaxInclusive: false,
+                               },
+                       },
+               },
+               {
+                       name: "ID ranges with one element having the same 
configuration as outer range when there is one " +
+                               "element in buckets",
+                       buckets: []bucket{
+                               {
+                                       ID: minMax{
+                                               Min: 5,
+                                               Max: 95,
+                                       },
+                               },
+                       },
+                       outerRange: idRange{
+                               Min:          0,
+                               MinInclusive: false,
+                               Max:          100,
+                               MaxInclusive: false,
+                       },
+                       want: []idRange{
+                               {
+                                       Min:          0,
+                                       MinInclusive: false,
+                                       Max:          100,
+                                       MaxInclusive: false,
+                               },
+                       },
+               },
+               {
+                       name:    "Empty ID ranges when there are no elements in 
buckets",
+                       buckets: nil,
+                       outerRange: idRange{
+                               Min:          0,
+                               MinInclusive: false,
+                               Max:          100,
+                               MaxInclusive: false,
+                       },
+                       want: nil,
+               },
+       }
+       for _, tt := range tests {
+               t.Run(tt.name, func(t *testing.T) {
+                       got := idRangesFromBuckets(tt.buckets, tt.outerRange)
+                       if diff := cmp.Diff(got, tt.want); diff != "" {
+                               t.Errorf("idRangesFromBuckets() mismatch (-want 
+got):\n%s", diff)
+                       }
+               })
+       }
+}
+
+func Test_getChunkSize(t *testing.T) {
+       tests := []struct {
+               name       string
+               bundleSize int64
+               want       int64
+       }{
+               {
+                       name:       "Return 1 MB if bundle size is less than 1 
MB",
+                       bundleSize: 1024,
+                       want:       1024 * 1024,
+               },
+               {
+                       name:       "Return 1 GB if bundle size is greater than 
1 GB",
+                       bundleSize: 2 * 1024 * 1024 * 1024,
+                       want:       1024 * 1024 * 1024,
+               },
+               {
+                       name:       "Return bundle size if bundle size is 
between 1 MB and 1 GB",
+                       bundleSize: 4 * 1024 * 1024,
+                       want:       4 * 1024 * 1024,
+               },
+       }
+       for _, tt := range tests {
+               t.Run(tt.name, func(t *testing.T) {
+                       if got := getChunkSize(tt.bundleSize); got != tt.want {
+                               t.Errorf("getChunkSize() = %v, want %v", got, 
tt.want)
+                       }
+               })
+       }
+}
+
+func Test_idRangesFromSplits(t *testing.T) {
+       tests := []struct {
+               name       string
+               splitKeys  []documentID
+               outerRange idRange
+               want       []idRange
+       }{
+               {
+                       name: "ID ranges with first element having min ID 
configuration from outer range, and last element " +
+                               "having max ID configuration from outer range",
+                       splitKeys: []documentID{
+                               {
+                                       ID: 100,
+                               },
+                               {
+                                       ID: 200,
+                               },
+                       },
+                       outerRange: idRange{
+                               Min:          0,
+                               MinInclusive: false,
+                               Max:          300,
+                               MaxInclusive: false,
+                       },
+                       want: []idRange{
+                               {
+                                       Min:          0,
+                                       MinInclusive: false,
+                                       Max:          100,
+                                       MaxInclusive: false,
+                               },
+                               {
+                                       Min:          100,
+                                       MinInclusive: true,
+                                       Max:          200,
+                                       MaxInclusive: false,
+                               },
+                               {
+                                       Min:          200,
+                                       MinInclusive: true,
+                                       Max:          300,
+                                       MaxInclusive: false,
+                               },
+                       },
+               },
+               {
+                       name: "ID ranges with one element having the same 
configuration as outer range when there are no " +
+                               "elements in key splits",
+                       splitKeys: nil,
+                       outerRange: idRange{
+                               Min:          0,
+                               MinInclusive: true,
+                               Max:          100,
+                               MaxInclusive: true,
+                       },
+                       want: []idRange{
+                               {
+                                       Min:          0,
+                                       MinInclusive: true,
+                                       Max:          100,
+                                       MaxInclusive: true,
+                               },
+                       },
+               },
+       }
+       for _, tt := range tests {
+               t.Run(tt.name, func(t *testing.T) {
+                       got := idRangesFromSplits(tt.splitKeys, tt.outerRange)
+                       if diff := cmp.Diff(got, tt.want); diff != "" {
+                               t.Errorf("idRangesFromSplits() mismatch (-want 
+got):\n%s", diff)
+                       }
+               })
+       }
+}
diff --git a/sdks/go/pkg/beam/io/mongodbio/id_range_tracker.go b/sdks/go/pkg/beam/io/mongodbio/id_range_tracker.go
new file mode 100644
index 00000000000..6b92ab57d49
--- /dev/null
+++ b/sdks/go/pkg/beam/io/mongodbio/id_range_tracker.go
@@ -0,0 +1,194 @@
+// Licensed to the Apache Software Foundation (ASF) under one or more
+// contributor license agreements.  See the NOTICE file distributed with
+// this work for additional information regarding copyright ownership.
+// The ASF licenses this file to You under the Apache License, Version 2.0
+// (the "License"); you may not use this file except in compliance with
+// the License.  You may obtain a copy of the License at
+//
+//    http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package mongodbio
+
+import (
+       "context"
+       "errors"
+       "fmt"
+       "reflect"
+
+       "github.com/apache/beam/sdks/v2/go/pkg/beam"
+       "go.mongodb.org/mongo-driver/mongo"
+)
+
+func init() {
+       beam.RegisterType(reflect.TypeOf((*idRangeTracker)(nil)))
+}
+
+// idRangeTracker is a tracker of an idRangeRestriction.
+type idRangeTracker struct {
+       rest       idRangeRestriction
+       collection *mongo.Collection
+       claimed    int64
+       claimedID  any
+       stopped    bool
+       err        error
+}
+
+// newIDRangeTracker creates a new idRangeTracker tracking the provided idRangeRestriction.
+func newIDRangeTracker(rest idRangeRestriction, collection *mongo.Collection) *idRangeTracker {
+       return &idRangeTracker{
+               rest:       rest,
+               collection: collection,
+       }
+}
+
+// cursorResult holds information about the next document to process from MongoDB. nextID is the ID
+// of the document. isExhausted is whether the cursor has been exhausted.
+type cursorResult struct {
+       nextID      any
+       isExhausted bool
+}
+
+// TryClaim accepts a position representing a cursorResult of a document to read from MongoDB. The
+// position is successfully claimed if the tracker has not yet completed the work within its
+// restriction and the cursor has not been exhausted.
+func (rt *idRangeTracker) TryClaim(pos any) (ok bool) {
+       result, ok := pos.(cursorResult)
+       if !ok {
+               rt.err = fmt.Errorf("invalid pos type: %T", pos)
+               return false
+       }
+
+       if rt.IsDone() {
+               return false
+       }
+
+       if result.isExhausted {
+               rt.stopped = true
+               return false
+       }
+
+       rt.claimed++
+       rt.claimedID = result.nextID
+
+       return true
+}
+
+// GetError returns the error associated with the tracker, if any.
+func (rt *idRangeTracker) GetError() error {
+       return rt.err
+}
+
+// TrySplit splits the underlying restriction into a primary and residual restriction based on the
+// fraction of remaining work the primary should be responsible for. The restriction may be modified
+// as a result of the split. The primary is a copy of the tracker's restriction after the split.
+// If the fraction is 1 or all work has already been claimed, returns the full restriction as the
+// primary and nil as the residual. If the fraction is 0, stops the tracker, cuts off any remaining
+// work from its underlying restriction, and returns a residual representing all remaining work.
+// If the fraction is between 0 and 1, attempts to split the remaining work of the underlying
+// restriction into two sub-restrictions based on the fraction and assigns them to the primary and
+// residual respectively. Returns an error if the split cannot be performed.
+func (rt *idRangeTracker) TrySplit(fraction float64) (primary, residual any, err error) {
+       if fraction < 0 || fraction > 1 {
+               return nil, nil, errors.New("fraction must be between 0 and 1")
+       }
+
+       done, remaining := rt.cutRestriction()
+
+       if fraction == 1 || remaining.Count == 0 {
+               return rt.rest, nil, nil
+       }
+
+       if fraction == 0 {
+               rt.rest = done
+               return rt.rest, remaining, nil
+       }
+
+       ctx := context.Background()
+
+       primaryRem, resid, err := remaining.FractionSplits(ctx, rt.collection, fraction)
+       if err != nil {
+               return nil, nil, err
+       }
+
+       if resid.Count == 0 {
+               return rt.rest, nil, nil
+       }
+
+       if primaryRem.Count == 0 {
+               rt.rest = done
+               return rt.rest, remaining, nil
+       }
+
+       rt.rest.IDRange.Max = primaryRem.IDRange.Max
+       rt.rest.IDRange.MaxInclusive = primaryRem.IDRange.MaxInclusive
+       rt.rest.Count -= resid.Count
+
+       return rt.rest, resid, nil
+}
+
+// cutRestriction returns two restrictions: done represents the amount of work from the underlying
+// restriction that has already been completed, and remaining represents the amount that remains to
+// be processed. Does not modify the underlying restriction.
+func (rt *idRangeTracker) cutRestriction() (done idRangeRestriction, remaining idRangeRestriction) {
+       minRem := rt.claimedID
+       minInclusiveRem := false
+       maxInclusiveDone := true
+
+       if minRem == nil {
+               minRem = rt.rest.IDRange.Min
+               minInclusiveRem = rt.rest.IDRange.MinInclusive
+               maxInclusiveDone = false
+       }
+
+       done = idRangeRestriction{
+               IDRange: idRange{
+                       Min:          rt.rest.IDRange.Min,
+                       MinInclusive: rt.rest.IDRange.MinInclusive,
+                       Max:          minRem,
+                       MaxInclusive: maxInclusiveDone,
+               },
+               CustomFilter: rt.rest.CustomFilter,
+               Count:        rt.claimed,
+       }
+
+       remaining = idRangeRestriction{
+               IDRange: idRange{
+                       Min:          minRem,
+                       MinInclusive: minInclusiveRem,
+                       Max:          rt.rest.IDRange.Max,
+                       MaxInclusive: rt.rest.IDRange.MaxInclusive,
+               },
+               CustomFilter: rt.rest.CustomFilter,
+               Count:        rt.rest.Count - rt.claimed,
+       }
+
+       return done, remaining
+}
+
+// GetProgress returns the amount of done and remaining work, represented by the count of documents.
+func (rt *idRangeTracker) GetProgress() (done float64, remaining float64) {
+       done = float64(rt.claimed)
+       remaining = float64(rt.rest.Count - rt.claimed)
+       return
+}
+
+// IsDone returns true if all work within the tracker's restriction has been completed.
+func (rt *idRangeTracker) IsDone() bool {
+       return rt.stopped || rt.claimed == rt.rest.Count
+}
+
+// GetRestriction returns a copy of the restriction the tracker is tracking.
+func (rt *idRangeTracker) GetRestriction() any {
+       return rt.rest
+}
+
+// IsBounded returns whether the tracker is tracking a restriction with a finite amount of work.
+func (*idRangeTracker) IsBounded() bool {
+       return true
+}
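
Taken together, the tracker is driven by the read SDF with one claim per cursor advance, plus a final "exhausted" claim so a range whose filtered count shrinks mid-read can still finish cleanly. Below is a minimal package-internal sketch of that loop; it is not code from this diff and assumes it lives in package mongodbio next to the types above, with next and emit standing in for the MongoDB cursor and the DoFn emitter.

    // processIDRange drains one restriction: next reports the _id of each
    // document from the cursor plus whether more documents remain, and emit
    // forwards claimed documents downstream.
    func processIDRange(rt *idRangeTracker, next func() (any, bool), emit func(any)) {
        for {
            id, more := next()
            if !more {
                // Claiming an exhausted position stops the tracker even if
                // fewer documents matched than the restriction's Count.
                rt.TryClaim(cursorResult{isExhausted: true})
                return
            }
            if !rt.TryClaim(cursorResult{nextID: id}) {
                return
            }
            emit(id)
        }
    }
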
diff --git a/sdks/go/pkg/beam/io/mongodbio/id_range_tracker_test.go b/sdks/go/pkg/beam/io/mongodbio/id_range_tracker_test.go
new file mode 100644
index 00000000000..a4484ee1007
--- /dev/null
+++ b/sdks/go/pkg/beam/io/mongodbio/id_range_tracker_test.go
@@ -0,0 +1,461 @@
+// Licensed to the Apache Software Foundation (ASF) under one or more
+// contributor license agreements.  See the NOTICE file distributed with
+// this work for additional information regarding copyright ownership.
+// The ASF licenses this file to You under the Apache License, Version 2.0
+// (the "License"); you may not use this file except in compliance with
+// the License.  You may obtain a copy of the License at
+//
+//    http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package mongodbio
+
+import (
+       "testing"
+
+       "github.com/google/go-cmp/cmp"
+       "go.mongodb.org/mongo-driver/bson"
+)
+
+func Test_idRangeTracker_TryClaim(t *testing.T) {
+       tests := []struct {
+               name          string
+               tracker       *idRangeTracker
+               pos           any
+               wantOk        bool
+               wantClaimed   int64
+               wantClaimedID any
+               wantDone      bool
+               wantErr       bool
+       }{
+               {
+                       name: "Return true when claimed count < total count - 1",
+                       tracker: &idRangeTracker{
+                               rest: idRangeRestriction{
+                                       Count: 10,
+                               },
+                               claimed:   5,
+                               claimedID: 123,
+                       },
+                       pos:           cursorResult{nextID: 124},
+                       wantOk:        true,
+                       wantClaimed:   6,
+                       wantClaimedID: 124,
+                       wantDone:      false,
+               },
+               {
+                       name: "Return true and set to done when claimed count == total count - 1",
+                       tracker: &idRangeTracker{
+                               rest: idRangeRestriction{
+                                       Count: 10,
+                               },
+                               claimed:   9,
+                               claimedID: 123,
+                       },
+                       pos:           cursorResult{nextID: 124},
+                       wantOk:        true,
+                       wantClaimed:   10,
+                       wantClaimedID: 124,
+                       wantDone:      true,
+               },
+               {
+                       name: "Return false and set to done when cursor is exhausted",
+                       tracker: &idRangeTracker{
+                               rest: idRangeRestriction{
+                                       Count: 10,
+                               },
+                               claimed:   5,
+                               claimedID: 123,
+                       },
+                       pos:           cursorResult{nextID: 124, isExhausted: true},
+                       wantOk:        false,
+                       wantClaimed:   5,
+                       wantClaimedID: 123,
+                       wantDone:      true,
+               },
+               {
+                       name: "Return false when claimed count == total count",
+                       tracker: &idRangeTracker{
+                               rest: idRangeRestriction{
+                                       Count: 10,
+                               },
+                               claimed:   10,
+                               claimedID: 123,
+                       },
+                       pos:           cursorResult{nextID: 124},
+                       wantOk:        false,
+                       wantClaimed:   10,
+                       wantClaimedID: 123,
+                       wantDone:      true,
+               },
+               {
+                       name: "Return false when tracker is stopped",
+                       tracker: &idRangeTracker{
+                               rest: idRangeRestriction{
+                                       Count: 10,
+                               },
+                               claimed:   10,
+                               claimedID: 123,
+                               stopped:   true,
+                       },
+                       pos:           cursorResult{nextID: 124},
+                       wantOk:        false,
+                       wantClaimed:   10,
+                       wantClaimedID: 123,
+                       wantDone:      true,
+               },
+               {
+                       name: "Return false and set error when pos is of invalid type",
+                       tracker: &idRangeTracker{
+                               rest: idRangeRestriction{
+                                       Count: 10,
+                               },
+                               claimed:   5,
+                               claimedID: 123,
+                       },
+                       pos:           "invalid",
+                       wantOk:        false,
+                       wantClaimed:   5,
+                       wantClaimedID: 123,
+                       wantDone:      false,
+                       wantErr:       true,
+               },
+       }
+       for _, tt := range tests {
+               t.Run(tt.name, func(t *testing.T) {
+                       if gotOk := tt.tracker.TryClaim(tt.pos); gotOk != tt.wantOk {
+                               t.Errorf("TryClaim() = %v, want %v", gotOk, tt.wantOk)
+                       }
+                       if gotClaimed := tt.tracker.claimed; gotClaimed != tt.wantClaimed {
+                               t.Errorf("claimed = %v, want %v", gotClaimed, tt.wantClaimed)
+                       }
+                       if gotClaimedID := tt.tracker.claimedID; !cmp.Equal(gotClaimedID, tt.wantClaimedID) {
+                               t.Errorf("claimedID = %v, want %v", gotClaimedID, tt.wantClaimedID)
+                       }
+                       if gotDone := tt.tracker.IsDone(); gotDone != tt.wantDone {
+                               t.Errorf("IsDone() = %v, want %v", gotDone, tt.wantDone)
+                       }
+                       if gotErr := tt.tracker.GetError(); (gotErr != nil) != tt.wantErr {
+                               t.Errorf("GetError() error = %v, wantErr %v", gotErr, tt.wantErr)
+                       }
+               })
+       }
+}
+
+func Test_idRangeTracker_TrySplit(t *testing.T) {
+       tests := []struct {
+               name         string
+               tracker      *idRangeTracker
+               fraction     float64
+               wantPrimary  any
+               wantResidual any
+               wantErr      bool
+               wantDone     bool
+       }{
+               {
+                       name: "Primary contains no more work and residual contains all remaining work when fraction is 0",
+                       tracker: &idRangeTracker{
+                               rest: idRangeRestriction{
+                                       IDRange: idRange{
+                                               Min:          0,
+                                               MinInclusive: true,
+                                               Max:          100,
+                                               MaxInclusive: false,
+                                       },
+                                       CustomFilter: bson.M{"key": "val"},
+                                       Count:        100,
+                               },
+                               claimed:   70,
+                               claimedID: 69,
+                       },
+                       fraction: 0,
+                       wantPrimary: idRangeRestriction{
+                               IDRange: idRange{
+                                       Min:          0,
+                                       MinInclusive: true,
+                                       Max:          69,
+                                       MaxInclusive: true,
+                               },
+                               CustomFilter: bson.M{"key": "val"},
+                               Count:        70,
+                       },
+                       wantResidual: idRangeRestriction{
+                               IDRange: idRange{
+                                       Min:          69,
+                                       MinInclusive: false,
+                                       Max:          100,
+                                       MaxInclusive: false,
+                               },
+                               CustomFilter: bson.M{"key": "val"},
+                               Count:        30,
+                       },
+                       wantDone: true,
+               },
+               {
+                       name: "Primary contains all original work and residual is nil when fraction is 1",
+                       tracker: &idRangeTracker{
+                               rest: idRangeRestriction{
+                                       IDRange: idRange{
+                                               Min:          0,
+                                               MinInclusive: true,
+                                               Max:          100,
+                                               MaxInclusive: false,
+                                       },
+                                       CustomFilter: bson.M{"key": "val"},
+                                       Count:        100,
+                               },
+                               claimed:   70,
+                               claimedID: 69,
+                       },
+                       fraction: 1,
+                       wantPrimary: idRangeRestriction{
+                               IDRange: idRange{
+                                       Min:          0,
+                                       MinInclusive: true,
+                                       Max:          100,
+                                       MaxInclusive: false,
+                               },
+                               CustomFilter: bson.M{"key": "val"},
+                               Count:        100,
+                       },
+                       wantResidual: nil,
+                       wantDone:     false,
+               },
+               {
+                       name: "Primary contains all original work and residual is nil when the total count has been claimed",
+                       tracker: &idRangeTracker{
+                               rest: idRangeRestriction{
+                                       IDRange: idRange{
+                                               Min:          0,
+                                               MinInclusive: true,
+                                               Max:          100,
+                                               MaxInclusive: false,
+                                       },
+                                       CustomFilter: bson.M{"key": "val"},
+                                       Count:        100,
+                               },
+                               claimed:   100,
+                               claimedID: 99,
+                       },
+                       fraction: 1,
+                       wantPrimary: idRangeRestriction{
+                               IDRange: idRange{
+                                       Min:          0,
+                                       MinInclusive: true,
+                                       Max:          100,
+                                       MaxInclusive: false,
+                               },
+                               CustomFilter: bson.M{"key": "val"},
+                               Count:        100,
+                       },
+                       wantResidual: nil,
+                       wantDone:     true,
+               },
+               {
+                       name: "Error - fraction is less than 0",
+                       tracker: &idRangeTracker{
+                               rest: idRangeRestriction{Count: 100},
+                       },
+                       fraction: -0.1,
+                       wantErr:  true,
+               },
+               {
+                       name: "Error - fraction is greater than 1",
+                       tracker: &idRangeTracker{
+                               rest: idRangeRestriction{Count: 100},
+                       },
+                       fraction: 1.1,
+                       wantErr:  true,
+               },
+       }
+       for _, tt := range tests {
+               t.Run(tt.name, func(t *testing.T) {
+                       gotPrimary, gotResidual, err := tt.tracker.TrySplit(tt.fraction)
+                       if (err != nil) != tt.wantErr {
+                               t.Fatalf("TrySplit() error = %v, wantErr %v", err, tt.wantErr)
+                       }
+                       if diff := cmp.Diff(gotPrimary, tt.wantPrimary); diff != "" {
+                               t.Errorf("TrySplit() gotPrimary mismatch (-want +got):\n%s", diff)
+                       }
+                       if diff := cmp.Diff(gotResidual, tt.wantResidual); diff != "" {
+                               t.Errorf("TrySplit() gotResidual mismatch (-want +got):\n%s", diff)
+                       }
+                       if tt.tracker.IsDone() != tt.wantDone {
+                               t.Errorf("IsDone() = %v, want %v", tt.tracker.IsDone(), tt.wantDone)
+                       }
+               })
+       }
+}
+
+func Test_idRangeTracker_cutRestriction(t *testing.T) {
+       tests := []struct {
+               name          string
+               tracker       *idRangeTracker
+               wantDone      idRangeRestriction
+               wantRemaining idRangeRestriction
+       }{
+               {
+                       name: "The tracker's claimedID is used as the max (inclusive) in the done restriction " +
+                               "and as the min (exclusive) in the remaining restriction when claimedID is not nil",
+                       tracker: &idRangeTracker{
+                               rest: idRangeRestriction{
+                                       IDRange: idRange{
+                                               Min:          0,
+                                               MinInclusive: true,
+                                               Max:          100,
+                                               MaxInclusive: false,
+                                       },
+                                       CustomFilter: bson.M{"key": "val"},
+                                       Count:        100,
+                               },
+                               claimed:   70,
+                               claimedID: 69,
+                       },
+                       wantDone: idRangeRestriction{
+                               IDRange: idRange{
+                                       Min:          0,
+                                       MinInclusive: true,
+                                       Max:          69,
+                                       MaxInclusive: true,
+                               },
+                               CustomFilter: bson.M{"key": "val"},
+                               Count:        70,
+                       },
+                       wantRemaining: idRangeRestriction{
+                               IDRange: idRange{
+                                       Min:          69,
+                                       MinInclusive: false,
+                                       Max:          100,
+                                       MaxInclusive: false,
+                               },
+                               CustomFilter: bson.M{"key": "val"},
+                               Count:        30,
+                       },
+               },
+               {
+                       name: "The tracker's restriction's min ID is used as the max (exclusive) in the done restriction " +
+                               "and as the min in the remaining restriction when claimedID is nil",
+                       tracker: &idRangeTracker{
+                               rest: idRangeRestriction{
+                                       IDRange: idRange{
+                                               Min:          0,
+                                               MinInclusive: false,
+                                               Max:          100,
+                                               MaxInclusive: false,
+                                       },
+                                       CustomFilter: bson.M{"key": "val"},
+                                       Count:        100,
+                               },
+                               claimed:   0,
+                               claimedID: nil,
+                       },
+                       wantDone: idRangeRestriction{
+                               IDRange: idRange{
+                                       Min:          0,
+                                       MinInclusive: false,
+                                       Max:          0,
+                                       MaxInclusive: false,
+                               },
+                               CustomFilter: bson.M{"key": "val"},
+                               Count:        0,
+                       },
+                       wantRemaining: idRangeRestriction{
+                               IDRange: idRange{
+                                       Min:          0,
+                                       MinInclusive: false,
+                                       Max:          100,
+                                       MaxInclusive: false,
+                               },
+                               CustomFilter: bson.M{"key": "val"},
+                               Count:        100,
+                       },
+               },
+       }
+       for _, tt := range tests {
+               t.Run(tt.name, func(t *testing.T) {
+                       gotDone, gotRemaining := tt.tracker.cutRestriction()
+                       if diff := cmp.Diff(gotDone, tt.wantDone); diff != "" {
+                               t.Errorf("cutRestriction() gotDone mismatch (-want +got):\n%s", diff)
+                       }
+                       if diff := cmp.Diff(gotRemaining, tt.wantRemaining); diff != "" {
+                               t.Errorf("cutRestriction() gotRemaining mismatch (-want +got):\n%s", diff)
+                       }
+               })
+       }
+}
+
+func Test_idRangeTracker_GetProgress(t *testing.T) {
+       tracker := &idRangeTracker{
+               rest: idRangeRestriction{
+                       Count: 100,
+               },
+               claimed: 30,
+       }
+       wantDone := float64(30)
+       wantRemaining := float64(70)
+
+       t.Run(
+               "Done is represented by claimed count, and remaining by total count - claimed count",
+               func(t *testing.T) {
+                       gotDone, gotRemaining := tracker.GetProgress()
+                       if gotDone != wantDone {
+                               t.Errorf("GetProgress() gotDone = %v, want %v", gotDone, wantDone)
+                       }
+                       if gotRemaining != wantRemaining {
+                               t.Errorf("GetProgress() gotRemaining = %v, want %v", gotRemaining, wantRemaining)
+                       }
+               },
+       )
+}
+
+func Test_idRangeTracker_IsDone(t *testing.T) {
+       tests := []struct {
+               name    string
+               tracker *idRangeTracker
+               want    bool
+       }{
+               {
+                       name: "True when the tracker's claimed count is equal to the total count",
+                       tracker: &idRangeTracker{
+                               rest: idRangeRestriction{
+                                       Count: 100,
+                               },
+                               claimed: 100,
+                       },
+                       want: true,
+               },
+               {
+                       name: "True when the tracker is stopped",
+                       tracker: &idRangeTracker{
+                               rest: idRangeRestriction{
+                                       Count: 100,
+                               },
+                               claimed: 95,
+                               stopped: true,
+                       },
+                       want: true,
+               },
+               {
+                       name: "False when the tracker is not stopped and its claimed count is less than the total count",
+                       tracker: &idRangeTracker{
+                               rest: idRangeRestriction{
+                                       Count: 100,
+                               },
+                               claimed: 95,
+                       },
+                       want: false,
+               },
+       }
+       for _, tt := range tests {
+               t.Run(tt.name, func(t *testing.T) {
+                       if got := tt.tracker.IsDone(); got != tt.want {
+                               t.Errorf("IsDone() = %v, want %v", got, tt.want)
+                       }
+               })
+       }
+}
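
The read.go changes below replace the separate bucketAuto and splitVector DoFns, plus the per-bundle readFn, with a single splittable readFn built on the restriction and tracker types above. As a schematic summary only (the concrete signatures are in the diff that follows), the SDF surface the new readFn exposes to the Go SDK looks like this, using the sdf and beam packages already imported by read.go:

    // Illustrative interface, not part of the patch.
    type splittableRead interface {
            CreateInitialRestriction(elem []byte) idRangeRestriction
            SplitRestriction(elem []byte, rest idRangeRestriction) []idRangeRestriction
            RestrictionSize(elem []byte, rest idRangeRestriction) float64
            CreateTracker(rest idRangeRestriction) *sdf.LockRTracker
            ProcessElement(ctx context.Context, rt *sdf.LockRTracker, elem []byte, emit func(beam.Y)) error
    }
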
diff --git a/sdks/go/pkg/beam/io/mongodbio/read.go b/sdks/go/pkg/beam/io/mongodbio/read.go
index b12a6d7738b..59d8cf6aef9 100644
--- a/sdks/go/pkg/beam/io/mongodbio/read.go
+++ b/sdks/go/pkg/beam/io/mongodbio/read.go
@@ -17,35 +17,28 @@ package mongodbio
 
 import (
        "context"
+       "errors"
        "fmt"
-       "math"
        "reflect"
 
        "github.com/apache/beam/sdks/v2/go/pkg/beam"
+       "github.com/apache/beam/sdks/v2/go/pkg/beam/core/sdf"
        "github.com/apache/beam/sdks/v2/go/pkg/beam/log"
        "github.com/apache/beam/sdks/v2/go/pkg/beam/register"
        "github.com/apache/beam/sdks/v2/go/pkg/beam/util/structx"
        "go.mongodb.org/mongo-driver/bson"
        "go.mongodb.org/mongo-driver/mongo"
        "go.mongodb.org/mongo-driver/mongo/options"
-       "go.mongodb.org/mongo-driver/mongo/readpref"
 )
 
 const (
        defaultReadBundleSize = 64 * 1024 * 1024
-
-       minSplitVectorChunkSize = 1024 * 1024
-       maxSplitVectorChunkSize = 1024 * 1024 * 1024
-
-       maxBucketCount = math.MaxInt32
 )
 
 func init() {
-       register.DoFn3x1[context.Context, []byte, func(bson.M), error](&bucketAutoFn{})
-       register.DoFn3x1[context.Context, []byte, func(bson.M), error](&splitVectorFn{})
-       register.Emitter1[bson.M]()
-
-       register.DoFn3x1[context.Context, bson.M, func(beam.Y), error](&readFn{})
+       register.DoFn4x1[context.Context, *sdf.LockRTracker, []byte, func(beam.Y), error](
+               &readFn{},
+       )
        register.Emitter1[beam.Y]()
 }
 
@@ -92,338 +85,169 @@ func Read(
 
        imp := beam.Impulse(s)
 
-       var bundled beam.PCollection
-
-       if option.BucketAuto {
-               bundled = beam.ParDo(s, newBucketAutoFn(uri, database, collection, option), imp)
-       } else {
-               bundled = beam.ParDo(s, newSplitVectorFn(uri, database, collection, option), imp)
-       }
-
        return beam.ParDo(
                s,
                newReadFn(uri, database, collection, t, option),
-               bundled,
+               imp,
                beam.TypeDefinition{Var: beam.YType, T: t},
        )
 }
 
-type bucketAutoFn struct {
+type readFn struct {
        mongoDBFn
+       BucketAuto bool
        BundleSize int64
+       Filter     []byte
+       Type       beam.EncodedType
+       filter     bson.M
+       projection bson.D
 }
 
-func newBucketAutoFn(
+func newReadFn(
        uri string,
        database string,
        collection string,
+       t reflect.Type,
        option *ReadOption,
-) *bucketAutoFn {
-       return &bucketAutoFn{
+) *readFn {
+       filter, err := encodeBSON[bson.M](option.Filter)
+       if err != nil {
+               panic(fmt.Sprintf("mongodbio.newReadFn: %v", err))
+       }
+
+       return &readFn{
                mongoDBFn: mongoDBFn{
                        URI:        uri,
                        Database:   database,
                        Collection: collection,
                },
+               BucketAuto: option.BucketAuto,
                BundleSize: option.BundleSize,
+               Filter:     filter,
+               Type:       beam.EncodedType{T: t},
        }
 }
 
-func (fn *bucketAutoFn) ProcessElement(
-       ctx context.Context,
-       _ []byte,
-       emit func(bson.M),
-) error {
-       collectionSize, err := fn.getCollectionSize(ctx)
-       if err != nil {
+func (fn *readFn) Setup(ctx context.Context) error {
+       var err error
+       if err = fn.mongoDBFn.Setup(ctx); err != nil {
                return err
        }
 
-       if collectionSize == 0 {
-               return nil
-       }
-
-       bucketCount := calculateBucketCount(collectionSize, fn.BundleSize)
-
-       buckets, err := fn.getBuckets(ctx, bucketCount)
+       fn.filter, err = decodeBSON[bson.M](fn.Filter)
        if err != nil {
                return err
        }
 
-       idFilters := idFiltersFromBuckets(buckets)
-
-       for _, filter := range idFilters {
-               emit(filter)
-       }
+       fn.projection = inferProjection(fn.Type.T, bsonTag)
 
        return nil
 }
 
-type collStats struct {
-       Size int64 `bson:"size"`
-}
-
-func (fn *bucketAutoFn) getCollectionSize(ctx context.Context) (int64, error) {
-       cmd := bson.M{"collStats": fn.Collection}
-       opts := options.RunCmd().SetReadPreference(readpref.Primary())
-
-       var stats collStats
-       if err := fn.collection.Database().RunCommand(ctx, cmd, opts).Decode(&stats); err != nil {
-               return 0, fmt.Errorf("error executing collStats command: %w", err)
-       }
-
-       return stats.Size, nil
-}
-
-func calculateBucketCount(collectionSize int64, bundleSize int64) int32 {
-       if bundleSize < 0 {
-               panic("monogdbio.calculateBucketCount: bundle size must be greater than 0")
+func inferProjection(t reflect.Type, tagKey string) bson.D {
+       names := structx.InferFieldNames(t, tagKey)
+       if len(names) == 0 {
+               panic("mongodbio.inferProjection: no names to infer projection from")
        }
 
-       count := collectionSize / bundleSize
-       if collectionSize%bundleSize != 0 {
-               count++
-       }
+       projection := make(bson.D, len(names))
 
-       if count > int64(maxBucketCount) {
-               count = maxBucketCount
+       for i, name := range names {
+               projection[i] = bson.E{Key: name, Value: 1}
        }
 
-       return int32(count)
-}
-
-type bucket struct {
-       ID minMax `bson:"_id"`
-}
-
-type minMax struct {
-       Min any `bson:"min"`
-       Max any `bson:"max"`
+       return projection
 }
 
-func (fn *bucketAutoFn) getBuckets(ctx context.Context, count int32) ([]bucket, error) {
-       pipeline := mongo.Pipeline{bson.D{{
-               Key: "$bucketAuto",
-               Value: bson.M{
-                       "groupBy": "$_id",
-                       "buckets": count,
-               },
-       }}}
-
-       opts := options.Aggregate().SetAllowDiskUse(true)
-
-       cursor, err := fn.collection.Aggregate(ctx, pipeline, opts)
-       if err != nil {
-               return nil, fmt.Errorf("error executing bucketAuto aggregation: %w", err)
+func (fn *readFn) CreateInitialRestriction(_ []byte) idRangeRestriction {
+       ctx := context.Background()
+       if err := fn.Setup(ctx); err != nil {
+               panic(err)
        }
 
-       var buckets []bucket
-       if err = cursor.All(ctx, &buckets); err != nil {
-               return nil, fmt.Errorf("error decoding buckets: %w", err)
-       }
-
-       return buckets, nil
-}
-
-func idFiltersFromBuckets(buckets []bucket) []bson.M {
-       idFilters := make([]bson.M, len(buckets))
-
-       for i := 0; i < len(buckets); i++ {
-               filter := bson.M{}
-
-               if i != 0 {
-                       filter["$gt"] = buckets[i].ID.Min
-               }
-
-               if i != len(buckets)-1 {
-                       filter["$lte"] = buckets[i].ID.Max
+       outerRange, err := findOuterIDRange(ctx, fn.collection, fn.filter)
+       if err != nil {
+               if errors.Is(err, mongo.ErrNoDocuments) {
+                       log.Infof(
+                               ctx,
+                               "No documents in collection %s.%s match the provided filter",
+                               fn.Database,
+                               fn.Collection,
+                       )
+                       return idRangeRestriction{}
                }
 
-               if len(filter) == 0 {
-                       idFilters[i] = filter
-               } else {
-                       idFilters[i] = bson.M{"_id": filter}
-               }
+               panic(err)
        }
 
-       return idFilters
-}
-
-type splitVectorFn struct {
-       mongoDBFn
-       BundleSize int64
-}
-
-func newSplitVectorFn(
-       uri string,
-       database string,
-       collection string,
-       option *ReadOption,
-) *splitVectorFn {
-       return &splitVectorFn{
-               mongoDBFn: mongoDBFn{
-                       URI:        uri,
-                       Database:   database,
-                       Collection: collection,
-               },
-               BundleSize: option.BundleSize,
-       }
+       return newIDRangeRestriction(
+               ctx,
+               fn.collection,
+               outerRange,
+               fn.filter,
+       )
 }
 
-func (fn *splitVectorFn) ProcessElement(
+func findOuterIDRange(
        ctx context.Context,
-       _ []byte,
-       emit func(bson.M),
-) error {
-       chunkSize := getChunkSize(fn.BundleSize)
-
-       splitKeys, err := fn.getSplitKeys(ctx, chunkSize)
+       collection *mongo.Collection,
+       filter bson.M,
+) (idRange, error) {
+       minID, err := findID(ctx, collection, filter, 1, 0)
        if err != nil {
-               return err
+               return idRange{}, err
        }
 
-       idFilters := idFiltersFromSplits(splitKeys)
-
-       for _, filter := range idFilters {
-               emit(filter)
+       maxID, err := findID(ctx, collection, filter, -1, 0)
+       if err != nil {
+               return idRange{}, err
        }
 
-       return nil
-}
-
-func getChunkSize(bundleSize int64) int64 {
-       var chunkSize int64
-
-       if bundleSize < minSplitVectorChunkSize {
-               chunkSize = minSplitVectorChunkSize
-       } else if bundleSize > maxSplitVectorChunkSize {
-               chunkSize = maxSplitVectorChunkSize
-       } else {
-               chunkSize = bundleSize
+       outerRange := idRange{
+               Min:          minID,
+               MinInclusive: true,
+               Max:          maxID,
+               MaxInclusive: true,
        }
 
-       return chunkSize
-}
-
-type splitVector struct {
-       SplitKeys []splitKey `bson:"splitKeys"`
-}
-
-type splitKey struct {
-       ID any `bson:"_id"`
+       return outerRange, nil
 }
 
-func (fn *splitVectorFn) getSplitKeys(ctx context.Context, chunkSize int64) ([]splitKey, error) {
-       cmd := bson.D{
-               {Key: "splitVector", Value: fmt.Sprintf("%s.%s", fn.Database, fn.Collection)},
-               {Key: "keyPattern", Value: bson.D{{Key: "_id", Value: 1}}},
-               {Key: "maxChunkSizeBytes", Value: chunkSize},
+func (fn *readFn) SplitRestriction(_ []byte, rest idRangeRestriction) []idRangeRestriction {
+       if rest.Count == 0 {
+               return []idRangeRestriction{rest}
        }
 
-       opts := options.RunCmd().SetReadPreference(readpref.Primary())
-
-       var vector splitVector
-       if err := fn.collection.Database().RunCommand(ctx, cmd, opts).Decode(&vector); err != nil {
-               return nil, fmt.Errorf("error executing splitVector command: %w", err)
+       ctx := context.Background()
+       if err := fn.Setup(ctx); err != nil {
+               panic(err)
        }
 
-       return vector.SplitKeys, nil
-}
-
-func idFiltersFromSplits(splitKeys []splitKey) []bson.M {
-       idFilters := make([]bson.M, len(splitKeys)+1)
-
-       for i := 0; i < len(splitKeys)+1; i++ {
-               filter := bson.M{}
-
-               if i > 0 {
-                       filter["$gt"] = splitKeys[i-1].ID
-               }
-
-               if i < len(splitKeys) {
-                       filter["$lte"] = splitKeys[i].ID
-               }
-
-               if len(filter) == 0 {
-                       idFilters[i] = filter
-               } else {
-                       idFilters[i] = bson.M{"_id": filter}
-               }
-       }
-
-       return idFilters
-}
-
-type readFn struct {
-       mongoDBFn
-       Filter     []byte
-       Type       beam.EncodedType
-       projection bson.D
-       filter     bson.M
-}
-
-func newReadFn(
-       uri string,
-       database string,
-       collection string,
-       t reflect.Type,
-       option *ReadOption,
-) *readFn {
-       filter, err := encodeBSONMap(option.Filter)
+       splits, err := rest.SizedSplits(ctx, fn.collection, fn.BundleSize, fn.BucketAuto)
        if err != nil {
-               panic(fmt.Sprintf("mongodbio.newReadFn: %v", err))
+               panic(err)
        }
 
-       return &readFn{
-               mongoDBFn: mongoDBFn{
-                       URI:        uri,
-                       Database:   database,
-                       Collection: collection,
-               },
-               Filter: filter,
-               Type:   beam.EncodedType{T: t},
-       }
+       return splits
 }
 
-func (fn *readFn) Setup(ctx context.Context) error {
-       if err := fn.mongoDBFn.Setup(ctx); err != nil {
-               return err
-       }
-
-       filter, err := decodeBSONMap(fn.Filter)
-       if err != nil {
-               return err
-       }
-
-       fn.filter = filter
-       fn.projection = inferProjection(fn.Type.T, bsonTag)
-
-       return nil
+func (fn *readFn) CreateTracker(rest idRangeRestriction) *sdf.LockRTracker {
+       return sdf.NewLockRTracker(newIDRangeTracker(rest, fn.collection))
 }
 
-func inferProjection(t reflect.Type, tagKey string) bson.D {
-       names := structx.InferFieldNames(t, tagKey)
-       if len(names) == 0 {
-               panic("mongodbio.inferProjection: no names to infer projection from")
-       }
-
-       projection := make(bson.D, len(names))
-
-       for i, name := range names {
-               projection[i] = bson.E{Key: name, Value: 1}
-       }
-
-       return projection
+func (fn *readFn) RestrictionSize(_ []byte, rest idRangeRestriction) float64 {
+       return float64(rest.Count)
 }
 
 func (fn *readFn) ProcessElement(
        ctx context.Context,
-       elem bson.M,
+       rt *sdf.LockRTracker,
+       _ []byte,
        emit func(beam.Y),
 ) (err error) {
-       mergedFilter := mergeFilters(elem, fn.filter)
+       rest := rt.GetRestriction().(idRangeRestriction)
 
-       cursor, err := fn.findDocuments(ctx, fn.projection, mergedFilter)
+       cursor, err := fn.getCursor(ctx, rest.Filter())
        if err != nil {
                return err
        }
@@ -442,53 +266,53 @@ func (fn *readFn) ProcessElement(
        }()
 
        for cursor.Next(ctx) {
-               value, err := decodeDocument(cursor, fn.Type.T)
+               id, value, err := decodeDocument(cursor, fn.Type.T)
                if err != nil {
                        return err
                }
 
-               emit(value)
-       }
-
-       return cursor.Err()
-}
+               result := cursorResult{nextID: id}
+               if !rt.TryClaim(result) {
+                       return cursor.Err()
+               }
 
-func mergeFilters(idFilter bson.M, customFilter bson.M) bson.M {
-       if len(idFilter) == 0 {
-               return customFilter
+               emit(value)
        }
 
-       if len(customFilter) == 0 {
-               return idFilter
-       }
+       result := cursorResult{isExhausted: true}
+       rt.TryClaim(result)
 
-       return bson.M{
-               "$and": []bson.M{idFilter, customFilter},
-       }
+       return cursor.Err()
 }
 
-func (fn *readFn) findDocuments(
+func (fn *readFn) getCursor(
        ctx context.Context,
-       projection bson.D,
        filter bson.M,
 ) (*mongo.Cursor, error) {
-       opts := options.Find().SetProjection(projection)
+       opts := options.Find().
+               SetProjection(fn.projection).
+               SetSort(bson.M{"_id": 1})
 
        cursor, err := fn.collection.Find(ctx, filter, opts)
        if err != nil {
-               return nil, fmt.Errorf("error finding documents: %w", err)
+               return nil, fmt.Errorf("error executing find command: %w", err)
        }
 
        return cursor, nil
 }
 
-func decodeDocument(cursor *mongo.Cursor, t reflect.Type) (any, error) {
+func decodeDocument(cursor *mongo.Cursor, t reflect.Type) (id any, value any, err error) {
+       var docID documentID
+       if err := cursor.Decode(&docID); err != nil {
+               return nil, nil, fmt.Errorf("error decoding document ID: %w", err)
+       }
+
        out := reflect.New(t).Interface()
        if err := cursor.Decode(out); err != nil {
-               return nil, fmt.Errorf("error decoding document: %w", err)
+               return nil, nil, fmt.Errorf("error decoding document: %w", err)
        }
 
-       value := reflect.ValueOf(out).Elem().Interface()
+       value = reflect.ValueOf(out).Elem().Interface()
 
-       return value, nil
+       return docID.ID, value, nil
 }
diff --git a/sdks/go/pkg/beam/io/mongodbio/read_test.go b/sdks/go/pkg/beam/io/mongodbio/read_test.go
index 5899457d5a8..666960b17b9 100644
--- a/sdks/go/pkg/beam/io/mongodbio/read_test.go
+++ b/sdks/go/pkg/beam/io/mongodbio/read_test.go
@@ -16,7 +16,6 @@
 package mongodbio
 
 import (
-       "math"
        "reflect"
        "testing"
 
@@ -24,250 +23,6 @@ import (
        "go.mongodb.org/mongo-driver/bson"
 )
 
-func Test_calculateBucketCount(t *testing.T) {
-       tests := []struct {
-               name           string
-               collectionSize int64
-               bundleSize     int64
-               want           int32
-       }{
-               {
-                       name:           "Return ceiling of collection size / bundle size",
-                       collectionSize: 3 * 1024 * 1024,
-                       bundleSize:     2 * 1024 * 1024,
-                       want:           2,
-               },
-               {
-                       name:           "Return max int32 when calculated count is greater than max int32",
-                       collectionSize: 1024 * 1024 * 1024 * 1024,
-                       bundleSize:     1,
-                       want:           math.MaxInt32,
-               },
-       }
-       for _, tt := range tests {
-               t.Run(tt.name, func(t *testing.T) {
-                       if got := calculateBucketCount(tt.collectionSize, tt.bundleSize); got != tt.want {
-                               t.Errorf("calculateBucketCount() = %v, want %v", got, tt.want)
-                       }
-               })
-       }
-}
-
-func Test_calculateBucketCountPanic(t *testing.T) {
-       t.Run("Panic when bundleSize is not greater than 0", func(t *testing.T) {
-               defer func() {
-                       if r := recover(); r == nil {
-                               t.Errorf("calculateBucketCount() does not panic")
-                       }
-               }()
-
-               calculateBucketCount(1024, 0)
-       })
-}
-
-func Test_idFiltersFromBuckets(t *testing.T) {
-       tests := []struct {
-               name    string
-               buckets []bucket
-               want    []bson.M
-       }{
-               {
-                       name: "Create one $lte filter for start range, one $gt filter for end range, and filters with both " +
-                               "$lte and $gt for ranges in between when there are three or more bucket elements",
-                       buckets: []bucket{
-                               {
-                                       ID: minMax{
-                                               Min: objectIDFromHex(t, "6384e03f24f854c1a8ce5378"),
-                                               Max: objectIDFromHex(t, "6384e03f24f854c1a8ce5380"),
-                                       },
-                               },
-                               {
-                                       ID: minMax{
-                                               Min: objectIDFromHex(t, "6384e03f24f854c1a8ce5380"),
-                                               Max: objectIDFromHex(t, "6384e03f24f854c1a8ce5382"),
-                                       },
-                               },
-                               {
-                                       ID: minMax{
-                                               Min: objectIDFromHex(t, "6384e03f24f854c1a8ce5382"),
-                                               Max: objectIDFromHex(t, "6384e03f24f854c1a8ce5384"),
-                                       },
-                               },
-                       },
-                       want: []bson.M{
-                               {
-                                       "_id": bson.M{
-                                               "$lte": objectIDFromHex(t, "6384e03f24f854c1a8ce5380"),
-                                       },
-                               },
-                               {
-                                       "_id": bson.M{
-                                               "$gt":  objectIDFromHex(t, "6384e03f24f854c1a8ce5380"),
-                                               "$lte": objectIDFromHex(t, "6384e03f24f854c1a8ce5382"),
-                                       },
-                               },
-                               {
-                                       "_id": bson.M{
-                                               "$gt": objectIDFromHex(t, "6384e03f24f854c1a8ce5382"),
-                                       },
-                               },
-                       },
-               },
-               {
-                       name: "Create one $lte filter for start range and one $gt filter for end range when there are two " +
-                               "bucket elements",
-                       buckets: []bucket{
-                               {
-                                       ID: minMax{
-                                               Min: objectIDFromHex(t, "6384e03f24f854c1a8ce5378"),
-                                               Max: objectIDFromHex(t, "6384e03f24f854c1a8ce5380"),
-                                       },
-                               },
-                               {
-                                       ID: minMax{
-                                               Min: objectIDFromHex(t, "6384e03f24f854c1a8ce5380"),
-                                               Max: objectIDFromHex(t, "6384e03f24f854c1a8ce5382"),
-                                       },
-                               },
-                       },
-                       want: []bson.M{
-                               {
-                                       "_id": bson.M{
-                                               "$lte": objectIDFromHex(t, "6384e03f24f854c1a8ce5380"),
-                                       },
-                               },
-                               {
-                                       "_id": bson.M{
-                                               "$gt": objectIDFromHex(t, "6384e03f24f854c1a8ce5380"),
-                                       },
-                               },
-                       },
-               },
-               {
-                       name: "Create an empty filter when there is one bucket element",
-                       buckets: []bucket{
-                               {
-                                       ID: minMax{
-                                               Min: objectIDFromHex(t, "6384e03f24f854c1a8ce5378"),
-                                               Max: objectIDFromHex(t, "6384e03f24f854c1a8ce5380"),
-                                       },
-                               },
-                       },
-                       want: []bson.M{{}},
-               },
-       }
-       for _, tt := range tests {
-               t.Run(tt.name, func(t *testing.T) {
-                       if got := idFiltersFromBuckets(tt.buckets); !cmp.Equal(got, tt.want) {
-                               t.Errorf("idFiltersFromBuckets() = %v, want %v", got, tt.want)
-                       }
-               })
-       }
-}
-
-func Test_getChunkSize(t *testing.T) {
-       tests := []struct {
-               name       string
-               bundleSize int64
-               want       int64
-       }{
-               {
-                       name:       "Return 1 MB if bundle size is less than 1 MB",
-                       bundleSize: 1024,
-                       want:       1024 * 1024,
-               },
-               {
-                       name:       "Return 1 GB if bundle size is greater than 1 GB",
-                       bundleSize: 2 * 1024 * 1024 * 1024,
-                       want:       1024 * 1024 * 1024,
-               },
-               {
-                       name:       "Return bundle size if bundle size is between 1 MB and 1 GB",
-                       bundleSize: 4 * 1024 * 1024,
-                       want:       4 * 1024 * 1024,
-               },
-       }
-       for _, tt := range tests {
-               t.Run(tt.name, func(t *testing.T) {
-                               t.Errorf("getChunkSize() = %v, want %v", got, tt.want)
tt.want)
-                       }
-               })
-       }
-}
-
-func Test_idFiltersFromSplits(t *testing.T) {
-       tests := []struct {
-               name      string
-               splitKeys []splitKey
-               want      []bson.M
-       }{
-               {
-                       name: "Create one $lte filter for start range, one $gt filter for end range, and filters with both " +
-                               "$lte and $gt for ranges in between when there are two or more splitKey elements",
-                       splitKeys: []splitKey{
-                               {
-                                       ID: objectIDFromHex(t, "6384e03f24f854c1a8ce5380"),
-                               },
-                               {
-                                       ID: objectIDFromHex(t, "6384e03f24f854c1a8ce5382"),
-                               },
-                       },
-                       want: []bson.M{
-                               {
-                                       "_id": bson.M{
-                                               "$lte": objectIDFromHex(t, "6384e03f24f854c1a8ce5380"),
-                                       },
-                               },
-                               {
-                                       "_id": bson.M{
-                                               "$gt":  objectIDFromHex(t, "6384e03f24f854c1a8ce5380"),
-                                               "$lte": objectIDFromHex(t, "6384e03f24f854c1a8ce5382"),
-                                       },
-                               },
-                               {
-                                       "_id": bson.M{
-                                               "$gt": objectIDFromHex(t, "6384e03f24f854c1a8ce5382"),
-                                       },
-                               },
-                       },
-               },
-               {
-                       name: "Create one $lte filter for start range and one $gt filter for end range when there is one " +
-                               "splitKey element",
-                       splitKeys: []splitKey{
-                               {
-                                       ID: objectIDFromHex(t, "6384e03f24f854c1a8ce5380"),
-                               },
-                       },
-                       want: []bson.M{
-                               {
-                                       "_id": bson.M{
-                                               "$lte": objectIDFromHex(t, "6384e03f24f854c1a8ce5380"),
-                                       },
-                               },
-                               {
-                                       "_id": bson.M{
-                                               "$gt": objectIDFromHex(t, "6384e03f24f854c1a8ce5380"),
-                                       },
-                               },
-                       },
-               },
-               {
-                       name:      "Create an empty filter when there are no splitKey elements",
-                       splitKeys: nil,
-                       want:      []bson.M{{}},
-               },
-       }
-       for _, tt := range tests {
-               t.Run(tt.name, func(t *testing.T) {
-                       if got := idFiltersFromSplits(tt.splitKeys); !cmp.Equal(got, tt.want) {
-                               t.Errorf("idFiltersFromSplits() = %v, want %v", got, tt.want)
-                       }
-               })
-       }
-}
-
 func Test_inferProjection(t *testing.T) {
        type doc struct {
                Field1 string `bson:"field1"`
@@ -313,81 +68,3 @@ func Test_inferProjectionPanic(t *testing.T) {
                inferProjection(reflect.TypeOf(doc{}), "bson")
        })
 }
-
-func Test_mergeFilters(t *testing.T) {
-       tests := []struct {
-               name     string
-               idFilter bson.M
-               filter   bson.M
-               want     bson.M
-       }{
-               {
-                       name: "Returned merged ID filter and custom filter in an $and filter",
-                       idFilter: bson.M{
-                               "_id": bson.M{
-                                       "$gte": 10,
-                               },
-                       },
-                       filter: bson.M{
-                               "key": bson.M{
-                                       "$ne": "value",
-                               },
-                       },
-                       want: bson.M{
-                               "$and": []bson.M{
-                                       {
-                                               "_id": bson.M{
-                                                       "$gte": 10,
-                                               },
-                                       },
-                                       {
-                                               "key": bson.M{
-                                                       "$ne": "value",
-                                               },
-                                       },
-                               },
-                       },
-               },
-               {
-                       name: "Return only ID filter when custom filter is empty",
-                       idFilter: bson.M{
-                               "_id": bson.M{
-                                       "$gte": 10,
-                               },
-                       },
-                       filter: bson.M{},
-                       want: bson.M{
-                               "_id": bson.M{
-                                       "$gte": 10,
-                               },
-                       },
-               },
-               {
-                       name:     "Return only custom filter when ID filter is empty",
-                       idFilter: bson.M{},
-                       filter: bson.M{
-                               "key": bson.M{
-                                       "$ne": "value",
-                               },
-                       },
-                       want: bson.M{
-                               "key": bson.M{
-                                       "$ne": "value",
-                               },
-                       },
-               },
-               {
-                       name:     "Return empty filter when both ID filter and custom filter are empty",
-                       idFilter: bson.M{},
-                       filter:   bson.M{},
-                       want:     bson.M{},
-               },
-       }
-       for _, tt := range tests {
-               t.Run(tt.name, func(t *testing.T) {
-                       if got := mergeFilters(tt.idFilter, tt.filter); !cmp.Equal(got, tt.want) {
-                               t.Errorf("mergeFilters() = %v, want %v", got, tt.want)
-                       }
-               })
-       }
-}
