This is an automated email from the ASF dual-hosted git repository.

hanahmily pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/skywalking-banyandb.git


The following commit(s) were added to refs/heads/main by this push:
     new 675a46df push down aggregation for topN query (#697)
675a46df is described below

commit 675a46df417ed4be9e967c1e94b9f5f67d6d1b93
Author: hui lai <1353307...@qq.com>
AuthorDate: Thu Jul 10 09:31:39 2025 +0800

    push down aggregation for topN query (#697)
    
    Co-authored-by: Gao Hongtao <hanahm...@gmail.com>
    Co-authored-by: 吴晟 Wu Sheng <wu.sh...@foxmail.com>
---
 CHANGES.md                                         |   1 +
 api/proto/banyandb/measure/v1/query.proto          |   2 +
 banyand/query/processor.go                         | 150 +++++++++++++++++++++
 docs/api-reference.md                              |   1 +
 .../logical/measure/measure_plan_distributed.go    |   7 +
 5 files changed, 161 insertions(+)

diff --git a/CHANGES.md b/CHANGES.md
index 99804b8a..5cc0071b 100644
--- a/CHANGES.md
+++ b/CHANGES.md
@@ -20,6 +20,7 @@ Release Notes.
 - Add Load Balancer Feature to Liaison. 
 - Implement fadvise for large files to prevent page cache pollution.
 - Data Model: Introduce the `Trace` data model to store the trace/span data.
+- Push down aggregation for TopN queries.
 
 ### Bug Fixes
 
diff --git a/api/proto/banyandb/measure/v1/query.proto 
b/api/proto/banyandb/measure/v1/query.proto
index 05f8309b..9bd0f045 100644
--- a/api/proto/banyandb/measure/v1/query.proto
+++ b/api/proto/banyandb/measure/v1/query.proto
@@ -113,4 +113,6 @@ message QueryRequest {
   bool trace = 13;
   // stages is used to specify the stage of the data points in the lifecycle
   repeated string stages = 14;
+  // rewriteAggTopNResult will rewrite agg result to raw data
+  bool rewrite_agg_top_n_result = 15;
 }
diff --git a/banyand/query/processor.go b/banyand/query/processor.go
index 7610997a..203d2dd5 100644
--- a/banyand/query/processor.go
+++ b/banyand/query/processor.go
@@ -24,9 +24,12 @@ import (
        "runtime/debug"
        "time"
 
+       "golang.org/x/exp/slices"
+
        "github.com/apache/skywalking-banyandb/api/common"
        commonv1 
"github.com/apache/skywalking-banyandb/api/proto/banyandb/common/v1"
        measurev1 
"github.com/apache/skywalking-banyandb/api/proto/banyandb/measure/v1"
+       modelv1 
"github.com/apache/skywalking-banyandb/api/proto/banyandb/model/v1"
        streamv1 
"github.com/apache/skywalking-banyandb/api/proto/banyandb/stream/v1"
        "github.com/apache/skywalking-banyandb/banyand/measure"
        "github.com/apache/skywalking-banyandb/banyand/stream"
@@ -160,6 +163,62 @@ func (p *measureQueryProcessor) Rev(ctx context.Context, 
message bus.Message) (r
                resp = bus.NewMessage(bus.MessageID(now), 
common.NewError("invalid event data type"))
                return
        }
+       if queryCriteria.RewriteAggTopNResult {
+               queryCriteria.Top.Number *= 2
+       }
+       resp = p.executeQuery(ctx, queryCriteria)
+
+       if queryCriteria.RewriteAggTopNResult {
+               result, handleErr := handleResponse(resp)
+               if handleErr != nil {
+                       return
+               }
+               if len(result) == 0 {
+                       return
+               }
+               groupByTags := make([]string, 0)
+               if queryCriteria.GetGroupBy() != nil {
+                       for _, tagFamily := range 
queryCriteria.GetGroupBy().GetTagProjection().GetTagFamilies() {
+                               groupByTags = append(groupByTags, 
tagFamily.GetTags()...)
+                       }
+               }
+               tagValueMap := make(map[string][]*modelv1.TagValue)
+               for _, dp := range result {
+                       for _, tagFamily := range dp.GetTagFamilies() {
+                               for _, tag := range tagFamily.GetTags() {
+                                       tagName := tag.GetKey()
+                                       if len(groupByTags) == 0 || 
slices.Contains(groupByTags, tagName) {
+                                               tagValueMap[tagName] = 
append(tagValueMap[tagName], tag.GetValue())
+                                       }
+                               }
+                       }
+               }
+               rewriteCriteria, err := rewriteCriteria(tagValueMap)
+               if err != nil {
+                       p.log.Error().Err(err).RawJSON("req", 
logger.Proto(queryCriteria)).Msg("fail to rewrite the query criteria")
+                       return
+               }
+               rewriteQueryCriteria := &measurev1.QueryRequest{
+                       Groups:          queryCriteria.Groups,
+                       Name:            queryCriteria.Name,
+                       TimeRange:       queryCriteria.TimeRange,
+                       Criteria:        rewriteCriteria,
+                       TagProjection:   queryCriteria.TagProjection,
+                       FieldProjection: queryCriteria.FieldProjection,
+               }
+               resp = p.executeQuery(ctx, rewriteQueryCriteria)
+               dataPoints, handleErr := handleResponse(resp)
+               if handleErr != nil {
+                       return
+               }
+               resp = bus.NewMessage(bus.MessageID(now), 
&measurev1.QueryResponse{DataPoints: dataPoints})
+       }
+       return
+}
+
+func (p *measureQueryProcessor) executeQuery(ctx context.Context, 
queryCriteria *measurev1.QueryRequest) (resp bus.Message) {
+       n := time.Now()
+       now := n.UnixNano()
        defer func() {
                if err := recover(); err != nil {
                        p.log.Error().Interface("err", err).RawJSON("req", 
logger.Proto(queryCriteria)).Str("stack", string(debug.Stack())).Msg("panic")
@@ -270,3 +329,94 @@ func (p *measureQueryProcessor) Rev(ctx context.Context, 
message bus.Message) (r
        }
        return
 }
+
+func handleResponse(resp bus.Message) ([]*measurev1.DataPoint, *common.Error) {
+       data := resp.Data()
+       switch d := data.(type) {
+       case *common.Error:
+               return nil, d
+       case *measurev1.QueryResponse:
+               return d.DataPoints, nil
+       default:
+               return nil, common.NewError("unexpected response data type: 
%T", d)
+       }
+}
+
+func rewriteCriteria(tagValueMap map[string][]*modelv1.TagValue) 
(*modelv1.Criteria, error) {
+       var tagConditions []*modelv1.Condition
+       for tagName, tagValues := range tagValueMap {
+               if len(tagValues) == 0 {
+                       continue
+               }
+               switch tagValues[0].GetValue().(type) {
+               case *modelv1.TagValue_Str:
+                       valueSet := make(map[string]bool)
+                       for _, value := range tagValues {
+                               if strVal, ok := 
value.GetValue().(*modelv1.TagValue_Str); ok {
+                                       valueSet[strVal.Str.GetValue()] = true
+                               }
+                       }
+                       values := make([]string, 0, len(valueSet))
+                       for value := range valueSet {
+                               values = append(values, value)
+                       }
+                       condition := &modelv1.Condition{
+                               Name: tagName,
+                               Op:   modelv1.Condition_BINARY_OP_IN,
+                               Value: &modelv1.TagValue{
+                                       Value: &modelv1.TagValue_StrArray{
+                                               StrArray: &modelv1.StrArray{
+                                                       Value: values,
+                                               },
+                                       },
+                               },
+                       }
+                       tagConditions = append(tagConditions, condition)
+               case *modelv1.TagValue_Int:
+                       valueSet := make(map[int64]bool)
+                       for _, value := range tagValues {
+                               if intVal, ok := 
value.GetValue().(*modelv1.TagValue_Int); ok {
+                                       valueSet[intVal.Int.GetValue()] = true
+                               }
+                       }
+                       values := make([]int64, 0, len(valueSet))
+                       for value := range valueSet {
+                               values = append(values, value)
+                       }
+                       condition := &modelv1.Condition{
+                               Name: tagName,
+                               Op:   modelv1.Condition_BINARY_OP_IN,
+                               Value: &modelv1.TagValue{
+                                       Value: &modelv1.TagValue_IntArray{
+                                               IntArray: &modelv1.IntArray{
+                                                       Value: values,
+                                               },
+                                       },
+                               },
+                       }
+                       tagConditions = append(tagConditions, condition)
+               default:
+                       return nil, fmt.Errorf("unsupported tag value type: 
%T", tagValues[0].GetValue())
+               }
+       }
+       return buildCriteriaTree(tagConditions), nil
+}
+
+func buildCriteriaTree(conditions []*modelv1.Condition) *modelv1.Criteria {
+       if len(conditions) == 0 {
+               return nil
+       }
+       return &modelv1.Criteria{
+               Exp: &modelv1.Criteria_Le{
+                       Le: &modelv1.LogicalExpression{
+                               Op: modelv1.LogicalExpression_LOGICAL_OP_AND,
+                               Left: &modelv1.Criteria{
+                                       Exp: &modelv1.Criteria_Condition{
+                                               Condition: conditions[0],
+                                       },
+                               },
+                               Right: buildCriteriaTree(conditions[1:]),
+                       },
+               },
+       }
+}
diff --git a/docs/api-reference.md b/docs/api-reference.md
index 4a5006d6..44f4654c 100644
--- a/docs/api-reference.md
+++ b/docs/api-reference.md
@@ -3003,6 +3003,7 @@ QueryRequest is the request contract for query.
 | order_by | [banyandb.model.v1.QueryOrder](#banyandb-model-v1-QueryOrder) |  
| order_by is given to specify the sort for a tag. |
 | trace | [bool](#bool) |  | trace is used to enable trace for the query |
 | stages | [string](#string) | repeated | stages is used to specify the stage 
of the data points in the lifecycle |
+| rewrite_agg_top_n_result | [bool](#bool) |  | rewriteAggTopNResult will 
rewrite agg result to raw data |
 
 
 
diff --git a/pkg/query/logical/measure/measure_plan_distributed.go 
b/pkg/query/logical/measure/measure_plan_distributed.go
index c6843029..e7dafb7a 100644
--- a/pkg/query/logical/measure/measure_plan_distributed.go
+++ b/pkg/query/logical/measure/measure_plan_distributed.go
@@ -90,6 +90,13 @@ func (ud *unresolvedDistributed) Analyze(s logical.Schema) 
(logical.Plan, error)
                Limit:           limit + ud.originalQuery.Offset,
                OrderBy:         ud.originalQuery.OrderBy,
        }
+       // push down groupBy, agg and top to data node and rewrite agg result 
to raw data
+       if ud.originalQuery.Agg != nil && ud.originalQuery.Top != nil {
+               temp.RewriteAggTopNResult = true
+               temp.Agg = ud.originalQuery.Agg
+               temp.Top = ud.originalQuery.Top
+               temp.GroupBy = ud.originalQuery.GroupBy
+       }
        if ud.groupByEntity {
                e := s.EntityList()[0]
                sortTagSpec := s.FindTagSpecByName(e)

Reply via email to