Copilot commented on code in PR #735:
URL: 
https://github.com/apache/skywalking-banyandb/pull/735#discussion_r2298596928


##########
banyand/trace/query.go:
##########
@@ -18,39 +18,298 @@
 package trace
 
 import (
-       "github.com/apache/skywalking-banyandb/api/common"
+       "container/heap"
+       "context"
+       "fmt"
+       "sort"
+
+       "github.com/pkg/errors"
+
        databasev1 
"github.com/apache/skywalking-banyandb/api/proto/banyandb/database/v1"
        modelv1 
"github.com/apache/skywalking-banyandb/api/proto/banyandb/model/v1"
+       "github.com/apache/skywalking-banyandb/banyand/internal/storage"
        "github.com/apache/skywalking-banyandb/pkg/convert"
        "github.com/apache/skywalking-banyandb/pkg/logger"
        pbv1 "github.com/apache/skywalking-banyandb/pkg/pb/v1"
        "github.com/apache/skywalking-banyandb/pkg/query/model"
 )
 
+const checkDoneEvery = 128
+
+var nilResult = model.TraceQueryResult(nil)
+
 type queryOptions struct {
+       traceID string
        model.TraceQueryOptions
-       seriesToEntity map[common.SeriesID][]*modelv1.TagValue
-       sortedSids     []common.SeriesID
-       minTimestamp   int64
-       maxTimestamp   int64
+       minTimestamp int64
+       maxTimestamp int64
 }
 
 func (qo *queryOptions) reset() {
        qo.TraceQueryOptions.Reset()
-       qo.seriesToEntity = nil
-       qo.sortedSids = nil
+       qo.traceID = ""
        qo.minTimestamp = 0
        qo.maxTimestamp = 0
 }
 
 func (qo *queryOptions) copyFrom(other *queryOptions) {
        qo.TraceQueryOptions.CopyFrom(&other.TraceQueryOptions)
-       qo.seriesToEntity = other.seriesToEntity
-       qo.sortedSids = other.sortedSids
+       qo.traceID = other.traceID
        qo.minTimestamp = other.minTimestamp
        qo.maxTimestamp = other.maxTimestamp
 }
 
+func (t *trace) Query(ctx context.Context, tqo model.TraceQueryOptions) 
(model.TraceQueryResult, error) {
+       if tqo.TimeRange == nil {
+               return nil, errors.New("invalid query options: timeRange are 
required")
+       }
+       if tqo.TagProjection == nil || len(tqo.TagProjection.Names) == 0 {
+               return nil, errors.New("invalid query options: tagProjection is 
required")
+       }
+       var tsdb storage.TSDB[*tsTable, option]
+       var err error
+       db := t.tsdb.Load()
+       if db == nil {
+               tsdb, err = t.schemaRepo.loadTSDB(t.group)
+               if err != nil {
+                       return nil, err
+               }
+               t.tsdb.Store(tsdb)
+       } else {
+               tsdb = db.(storage.TSDB[*tsTable, option])
+       }
+
+       segments, err := tsdb.SelectSegments(*tqo.TimeRange)
+       if err != nil {
+               return nil, err
+       }
+       if len(segments) < 1 {
+               return nilResult, nil
+       }
+
+       result := queryResult{
+               ctx:           ctx,
+               segments:      segments,
+               tagProjection: tqo.TagProjection,
+       }
+       defer func() {
+               if err != nil {
+                       result.Release()
+               }
+       }()
+       var parts []*part
+       qo := queryOptions{
+               TraceQueryOptions: tqo,
+               traceID:           "",
+               minTimestamp:      tqo.TimeRange.Start.UnixNano(),
+               maxTimestamp:      tqo.TimeRange.End.UnixNano(),
+       }
+       var n int
+       tables := make([]*tsTable, 0)
+       for _, segment := range segments {
+               tt, _ := segment.Tables()
+               tables = append(tables, tt...)
+       }
+       for i := range tables {
+               s := tables[i].currentSnapshot()
+               if s == nil {
+                       continue
+               }
+               parts, n = s.getParts(parts, qo.minTimestamp, qo.maxTimestamp)
+               if n < 1 {
+                       s.decRef()
+                       continue
+               }
+               result.snapshots = append(result.snapshots, s)
+       }
+
+       if err = t.searchBlocks(ctx, &result, parts, qo); err != nil {
+               return nil, err
+       }
+
+       return &result, nil
+}
+
+func (t *trace) searchBlocks(ctx context.Context, result *queryResult, parts 
[]*part, qo queryOptions) error {
+       bma := generateBlockMetadataArray()
+       defer releaseBlockMetadataArray(bma)
+       defFn := startBlockScanSpan(ctx, qo.traceID, parts, result)
+       defer defFn()
+       tstIter := generateTstIter()
+       defer releaseTstIter(tstIter)
+       tstIter.init(bma, parts, qo.traceID)
+       if tstIter.Error() != nil {
+               return fmt.Errorf("cannot init tstIter: %w", tstIter.Error())
+       }
+       var hit int
+       var spanBlockBytes uint64
+       quota := t.pm.AvailableBytes()
+       for tstIter.nextBlock() {
+               if hit%checkDoneEvery == 0 {
+                       select {
+                       case <-ctx.Done():
+                               return errors.WithMessagef(ctx.Err(), 
"interrupt: scanned %d blocks, remained %d/%d parts to scan",
+                                       len(result.data), 
len(tstIter.piPool)-tstIter.idx, len(tstIter.piPool))
+                       default:
+                       }
+               }
+               hit++
+               bc := generateBlockCursor()
+               p := tstIter.piPool[tstIter.idx]
+               bc.init(p.p, p.curBlock, qo)
+               result.data = append(result.data, bc)
+               spanBlockBytes += bc.bm.uncompressedSpanSizeBytes
+               if quota >= 0 && spanBlockBytes > uint64(quota) {
+                       return fmt.Errorf("block scan quota exceeded: used %d 
bytes, quota is %d bytes", spanBlockBytes, quota)
+               }
+       }
+       if tstIter.Error() != nil {
+               return fmt.Errorf("cannot iterate tstIter: %w", tstIter.Error())
+       }
+       return t.pm.AcquireResource(ctx, spanBlockBytes)
+}
+
+type queryResult struct {
+       ctx           context.Context
+       tagProjection *model.TagProjection
+       data          []*blockCursor
+       snapshots     []*snapshot
+       segments      []storage.Segment[*tsTable, option]
+       hit           int
+       loaded        bool
+}
+
+func (qr *queryResult) Pull() *model.TraceResult {
+       select {
+       case <-qr.ctx.Done():
+               return &model.TraceResult{
+                       Error: errors.WithMessagef(qr.ctx.Err(), "interrupt: 
hit %d", qr.hit),
+               }
+       default:
+       }
+       if !qr.loaded {
+               if len(qr.data) == 0 {
+                       return nil
+               }
+
+               cursorChan := make(chan int, len(qr.data))
+               for i := 0; i < len(qr.data); i++ {
+                       go func(i int) {
+                               select {
+                               case <-qr.ctx.Done():
+                                       cursorChan <- i
+                                       return
+                               default:
+                               }
+                               tmpBlock := generateBlock()
+                               defer releaseBlock(tmpBlock)
+                               if !qr.data[i].loadData(tmpBlock) {
+                                       cursorChan <- i
+                                       return
+                               }
+                               cursorChan <- -1
+                       }(i)
+               }
+
+               blankCursorList := []int{}
+               for completed := 0; completed < len(qr.data); completed++ {
+                       result := <-cursorChan
+                       if result != -1 {
+                               blankCursorList = append(blankCursorList, 
result)
+                       }
+               }
+               select {
+               case <-qr.ctx.Done():
+                       return &model.TraceResult{
+                               Error: errors.WithMessagef(qr.ctx.Err(), 
"interrupt: blank/total=%d/%d", len(blankCursorList), len(qr.data)),
+                       }
+               default:
+               }
+               sort.Slice(blankCursorList, func(i, j int) bool {
+                       return blankCursorList[i] > blankCursorList[j]
+               })
+               for _, index := range blankCursorList {
+                       qr.data = append(qr.data[:index], qr.data[index+1:]...)
+               }
+               qr.loaded = true
+               heap.Init(qr)
+       }
+       if len(qr.data) == 0 {
+               return nil
+       }
+       if len(qr.data) == 1 {
+               r := &model.TraceResult{}
+               bc := qr.data[0]
+               bc.copyAllTo(r, false)
+               qr.data = qr.data[:0]
+               releaseBlockCursor(bc)
+               return r
+       }
+       return qr.merge()
+}
+
+func (qr *queryResult) Release() {
+       for i, v := range qr.data {
+               releaseBlockCursor(v)
+               qr.data[i] = nil
+       }
+       qr.data = qr.data[:0]
+       for i := range qr.snapshots {
+               qr.snapshots[i].decRef()
+       }
+       qr.snapshots = qr.snapshots[:0]
+       for i := range qr.segments {
+               qr.segments[i].DecRef()
+       }
+}
+
+func (qr queryResult) Len() int {
+       return len(qr.data)
+}
+
+func (qr queryResult) Less(i, j int) bool {
+       return qr.data[i].bm.traceID < qr.data[j].bm.traceID
+}
+
+func (qr queryResult) Swap(i, j int) {
+       qr.data[i], qr.data[j] = qr.data[j], qr.data[i]
+}
+
+func (qr *queryResult) Push(x interface{}) {
+       qr.data = append(qr.data, x.(*blockCursor))

Review Comment:
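   [nitpick] The Push and Pop methods use interface{} instead of type-safe 
generics. Note that `container/heap.Interface` fixes these signatures as 
`Push(x any)` and `Pop() any`, so a type parameter can only be introduced via a 
typed wrapper around the heap. Consider such a wrapper, or at least adding a 
checked type assertion with a descriptive error to make the code more robust.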
   ```suggestion
        bc, ok := x.(*blockCursor)
        if !ok {
                panic(fmt.Sprintf("queryResult.Push: expected *blockCursor, got 
%T", x))
        }
        qr.data = append(qr.data, bc)
   ```



##########
banyand/trace/part_iter.go:
##########
@@ -193,24 +154,23 @@ func (pi *partIter) readPrimaryBlock(bms []blockMetadata, 
mr *primaryBlockMetada
 
 func (pi *partIter) findBlock() bool {
        bhs := pi.bms
-       for len(bhs) > 0 {
+       if len(bhs) > 0 {
                tid := pi.curBlock.traceID
                if bhs[0].traceID < tid {
                        n := sort.Search(len(bhs), func(i int) bool {
                                return tid <= bhs[i].traceID
                        })
                        if n == len(bhs) {
-                               break
+                               pi.bms = nil
+                               return false
                        }
                        bhs = bhs[n:]
                }
                bm := &bhs[0]
 
-               if bm.traceID != tid {
-                       if !pi.searchTargetTraceID(bm.traceID) {
-                               return false
-                       }
-                       continue
+               if bm.traceID > tid {
+                       pi.bms = bhs[:0]

Review Comment:
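   Setting `pi.bms = bhs[:0]` creates a zero-length slice that still retains the 
underlying array's capacity, so the iterator keeps that array reachable. Use 
`pi.bms = nil` instead to clear the reference and allow garbage collection; this 
also matches the `n == len(bhs)` branch in the same hunk, which already sets 
`pi.bms = nil`.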
   ```suggestion
                        pi.bms = nil
   ```



##########
banyand/trace/query.go:
##########
@@ -18,39 +18,298 @@
 package trace
 
 import (
-       "github.com/apache/skywalking-banyandb/api/common"
+       "container/heap"
+       "context"
+       "fmt"
+       "sort"
+
+       "github.com/pkg/errors"
+
        databasev1 
"github.com/apache/skywalking-banyandb/api/proto/banyandb/database/v1"
        modelv1 
"github.com/apache/skywalking-banyandb/api/proto/banyandb/model/v1"
+       "github.com/apache/skywalking-banyandb/banyand/internal/storage"
        "github.com/apache/skywalking-banyandb/pkg/convert"
        "github.com/apache/skywalking-banyandb/pkg/logger"
        pbv1 "github.com/apache/skywalking-banyandb/pkg/pb/v1"
        "github.com/apache/skywalking-banyandb/pkg/query/model"
 )
 
+const checkDoneEvery = 128
+
+var nilResult = model.TraceQueryResult(nil)
+
 type queryOptions struct {
+       traceID string
        model.TraceQueryOptions
-       seriesToEntity map[common.SeriesID][]*modelv1.TagValue
-       sortedSids     []common.SeriesID
-       minTimestamp   int64
-       maxTimestamp   int64
+       minTimestamp int64
+       maxTimestamp int64
 }
 
 func (qo *queryOptions) reset() {
        qo.TraceQueryOptions.Reset()
-       qo.seriesToEntity = nil
-       qo.sortedSids = nil
+       qo.traceID = ""
        qo.minTimestamp = 0
        qo.maxTimestamp = 0
 }
 
 func (qo *queryOptions) copyFrom(other *queryOptions) {
        qo.TraceQueryOptions.CopyFrom(&other.TraceQueryOptions)
-       qo.seriesToEntity = other.seriesToEntity
-       qo.sortedSids = other.sortedSids
+       qo.traceID = other.traceID
        qo.minTimestamp = other.minTimestamp
        qo.maxTimestamp = other.maxTimestamp
 }
 
+func (t *trace) Query(ctx context.Context, tqo model.TraceQueryOptions) 
(model.TraceQueryResult, error) {
+       if tqo.TimeRange == nil {
+               return nil, errors.New("invalid query options: timeRange are 
required")
+       }
+       if tqo.TagProjection == nil || len(tqo.TagProjection.Names) == 0 {
+               return nil, errors.New("invalid query options: tagProjection is 
required")
+       }
+       var tsdb storage.TSDB[*tsTable, option]
+       var err error
+       db := t.tsdb.Load()
+       if db == nil {
+               tsdb, err = t.schemaRepo.loadTSDB(t.group)
+               if err != nil {
+                       return nil, err
+               }
+               t.tsdb.Store(tsdb)
+       } else {
+               tsdb = db.(storage.TSDB[*tsTable, option])
+       }
+
+       segments, err := tsdb.SelectSegments(*tqo.TimeRange)
+       if err != nil {
+               return nil, err
+       }
+       if len(segments) < 1 {
+               return nilResult, nil
+       }
+
+       result := queryResult{
+               ctx:           ctx,
+               segments:      segments,
+               tagProjection: tqo.TagProjection,
+       }
+       defer func() {
+               if err != nil {
+                       result.Release()
+               }
+       }()
+       var parts []*part
+       qo := queryOptions{
+               TraceQueryOptions: tqo,
+               traceID:           "",
+               minTimestamp:      tqo.TimeRange.Start.UnixNano(),
+               maxTimestamp:      tqo.TimeRange.End.UnixNano(),
+       }
+       var n int
+       tables := make([]*tsTable, 0)
+       for _, segment := range segments {
+               tt, _ := segment.Tables()
+               tables = append(tables, tt...)
+       }
+       for i := range tables {
+               s := tables[i].currentSnapshot()
+               if s == nil {
+                       continue
+               }
+               parts, n = s.getParts(parts, qo.minTimestamp, qo.maxTimestamp)
+               if n < 1 {
+                       s.decRef()
+                       continue
+               }
+               result.snapshots = append(result.snapshots, s)
+       }
+
+       if err = t.searchBlocks(ctx, &result, parts, qo); err != nil {
+               return nil, err
+       }
+
+       return &result, nil
+}
+
+func (t *trace) searchBlocks(ctx context.Context, result *queryResult, parts 
[]*part, qo queryOptions) error {
+       bma := generateBlockMetadataArray()
+       defer releaseBlockMetadataArray(bma)
+       defFn := startBlockScanSpan(ctx, qo.traceID, parts, result)
+       defer defFn()
+       tstIter := generateTstIter()
+       defer releaseTstIter(tstIter)
+       tstIter.init(bma, parts, qo.traceID)
+       if tstIter.Error() != nil {
+               return fmt.Errorf("cannot init tstIter: %w", tstIter.Error())
+       }
+       var hit int
+       var spanBlockBytes uint64
+       quota := t.pm.AvailableBytes()
+       for tstIter.nextBlock() {
+               if hit%checkDoneEvery == 0 {
+                       select {
+                       case <-ctx.Done():
+                               return errors.WithMessagef(ctx.Err(), 
"interrupt: scanned %d blocks, remained %d/%d parts to scan",
+                                       len(result.data), 
len(tstIter.piPool)-tstIter.idx, len(tstIter.piPool))
+                       default:
+                       }
+               }
+               hit++
+               bc := generateBlockCursor()
+               p := tstIter.piPool[tstIter.idx]
+               bc.init(p.p, p.curBlock, qo)
+               result.data = append(result.data, bc)
+               spanBlockBytes += bc.bm.uncompressedSpanSizeBytes
+               if quota >= 0 && spanBlockBytes > uint64(quota) {
+                       return fmt.Errorf("block scan quota exceeded: used %d 
bytes, quota is %d bytes", spanBlockBytes, quota)
+               }
+       }
+       if tstIter.Error() != nil {
+               return fmt.Errorf("cannot iterate tstIter: %w", tstIter.Error())
+       }
+       return t.pm.AcquireResource(ctx, spanBlockBytes)
+}
+
+type queryResult struct {
+       ctx           context.Context
+       tagProjection *model.TagProjection
+       data          []*blockCursor
+       snapshots     []*snapshot
+       segments      []storage.Segment[*tsTable, option]
+       hit           int
+       loaded        bool
+}
+
+func (qr *queryResult) Pull() *model.TraceResult {
+       select {
+       case <-qr.ctx.Done():
+               return &model.TraceResult{
+                       Error: errors.WithMessagef(qr.ctx.Err(), "interrupt: 
hit %d", qr.hit),
+               }
+       default:
+       }
+       if !qr.loaded {
+               if len(qr.data) == 0 {
+                       return nil
+               }
+
+               cursorChan := make(chan int, len(qr.data))
+               for i := 0; i < len(qr.data); i++ {
+                       go func(i int) {
+                               select {
+                               case <-qr.ctx.Done():
+                                       cursorChan <- i
+                                       return
+                               default:
+                               }
+                               tmpBlock := generateBlock()
+                               defer releaseBlock(tmpBlock)
+                               if !qr.data[i].loadData(tmpBlock) {
+                                       cursorChan <- i
+                                       return
+                               }
+                               cursorChan <- -1
+                       }(i)
+               }
+
+               blankCursorList := []int{}
+               for completed := 0; completed < len(qr.data); completed++ {
+                       result := <-cursorChan
+                       if result != -1 {
+                               blankCursorList = append(blankCursorList, 
result)
+                       }
+               }
+               select {
+               case <-qr.ctx.Done():
+                       return &model.TraceResult{
+                               Error: errors.WithMessagef(qr.ctx.Err(), 
"interrupt: blank/total=%d/%d", len(blankCursorList), len(qr.data)),
+                       }
+               default:
+               }
+               sort.Slice(blankCursorList, func(i, j int) bool {
+                       return blankCursorList[i] > blankCursorList[j]
+               })
+               for _, index := range blankCursorList {
+                       qr.data = append(qr.data[:index], qr.data[index+1:]...)
+               }
+               qr.loaded = true
+               heap.Init(qr)
+       }
+       if len(qr.data) == 0 {
+               return nil
+       }
+       if len(qr.data) == 1 {
+               r := &model.TraceResult{}
+               bc := qr.data[0]
+               bc.copyAllTo(r, false)
+               qr.data = qr.data[:0]
+               releaseBlockCursor(bc)
+               return r
+       }
+       return qr.merge()
+}
+
+func (qr *queryResult) Release() {
+       for i, v := range qr.data {
+               releaseBlockCursor(v)
+               qr.data[i] = nil
+       }
+       qr.data = qr.data[:0]
+       for i := range qr.snapshots {
+               qr.snapshots[i].decRef()
+       }
+       qr.snapshots = qr.snapshots[:0]
+       for i := range qr.segments {
+               qr.segments[i].DecRef()
+       }
+}
+
+func (qr queryResult) Len() int {
+       return len(qr.data)
+}
+
+func (qr queryResult) Less(i, j int) bool {
+       return qr.data[i].bm.traceID < qr.data[j].bm.traceID
+}
+
+func (qr queryResult) Swap(i, j int) {
+       qr.data[i], qr.data[j] = qr.data[j], qr.data[i]
+}
+
+func (qr *queryResult) Push(x interface{}) {
+       qr.data = append(qr.data, x.(*blockCursor))
+}
+
+func (qr *queryResult) Pop() interface{} {
+       old := qr.data
+       n := len(old)
+       x := old[n-1]
+       qr.data = old[0 : n-1]
+       releaseBlockCursor(x)
+       return x
+}

Review Comment:
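   [nitpick] The Push and Pop methods use interface{} instead of type-safe 
generics. Note that `container/heap.Interface` fixes these signatures as 
`Push(x any)` and `Pop() any`, so a type parameter can only be introduced via a 
typed wrapper around the heap. Consider such a wrapper, or at least adding a 
checked type assertion with a descriptive error to make the code more robust.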



-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: [email protected]

For queries about this service, please contact Infrastructure at:
[email protected]

Reply via email to