Copilot commented on code in PR #878:
URL: https://github.com/apache/dubbo-go-pixiu/pull/878#discussion_r2776695050
##########
pkg/filter/ai/kvcache/lmcache_client.go:
##########
@@ -0,0 +1,181 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package kvcache
+
+import (
+    "context"
+    "encoding/json"
+    "fmt"
+    "strings"
+    "time"
+)
+
+import (
+    "github.com/go-resty/resty/v2"
+)
+
+type LMCacheClient struct {
+    httpClient     *resty.Client
+    baseURL        string
+    retry          RetryConfig
+    circuitBreaker *CircuitBreaker
+}
+
+type PinRequest struct {
+    Tokens     []int  `json:"tokens"`
+    InstanceID string `json:"instance_id"`
+    Location   string `json:"location"`
+}
+
+type LookupRequest struct {
+    Tokens []int `json:"tokens"`
+}
+
+type CompressRequest struct {
+    Tokens     []int  `json:"tokens"`
+    InstanceID string `json:"instance_id"`
+    Location   string `json:"location"`
+    Method     string `json:"method"`
+}
+
+type EvictRequest struct {
+    Tokens     []int  `json:"tokens"`
+    InstanceID string `json:"instance_id"`
+}
+
+type PinResponse struct {
+    EventID   string `json:"event_id"`
+    NumTokens int    `json:"num_tokens"`
+}
+
+type CompressResponse struct {
+    EventID   string `json:"event_id"`
+    NumTokens int    `json:"num_tokens"`
+}
+
+type EvictResponse struct {
+    EventID   string `json:"event_id"`
+    NumTokens int    `json:"num_tokens"`
+}
+
+func NewLMCacheClient(baseURL string, httpClient *resty.Client, retry RetryConfig, cb *CircuitBreaker) *LMCacheClient {
+    return &LMCacheClient{
+        httpClient:     httpClient,
+        baseURL:        strings.TrimRight(baseURL, "/"),
+        retry:          retry,
+        circuitBreaker: cb,
+    }
+}
+
+func (lc *LMCacheClient) Pin(ctx context.Context, req *PinRequest) (*PinResponse, error) {
+    var resp PinResponse
+    if err := lc.doRequestWithRetry(ctx, "/pin", req, &resp, "pin"); err != nil {
+        return nil, err
+    }
+    return &resp, nil
+}
+
+func (lc *LMCacheClient) Lookup(ctx context.Context, req *LookupRequest) (*LookupResponse, error) {
+    var resp LookupResponse
+    if err := lc.doRequestWithRetry(ctx, "/lookup", req, &resp, "lookup"); err != nil {
+        return nil, err
+    }
+    return &resp, nil
+}
+
+func (lc *LMCacheClient) Compress(ctx context.Context, req *CompressRequest) (*CompressResponse, error) {
+    var resp CompressResponse
+    if err := lc.doRequestWithRetry(ctx, "/compress", req, &resp, "compress"); err != nil {
+        return nil, err
+    }
+    return &resp, nil
+}
+
+func (lc *LMCacheClient) Evict(ctx context.Context, req *EvictRequest) (*EvictResponse, error) {
+    var resp EvictResponse
+    if err := lc.doRequestWithRetry(ctx, "/evict", req, &resp, "evict"); err != nil {
+        return nil, err
+    }
+    return &resp, nil
+}
+
+func (lc *LMCacheClient) doRequestWithRetry(ctx context.Context, path string, payload any, out any, op string) error {
+    var lastErr error
+    for attempt := 0; attempt < lc.retry.MaxAttempts; attempt++ {

Review Comment:
`retry.max_attempts` can be configured as 0 (Validate allows >= 0), but when it is 0 this retry loop runs zero iterations and returns `lastErr`, which is still `nil`, so callers treat the request as successful even though no HTTP call was made. Ensure at least one attempt is executed (e.g., treat 0 as 1, or require `max_attempts >= 1`).
```suggestion
    maxAttempts := lc.retry.MaxAttempts
    if maxAttempts < 1 {
        maxAttempts = 1
    }
    for attempt := 0; attempt < maxAttempts; attempt++ {
```
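To lock the fix in, a focused unit test could assert that at least one HTTP attempt happens even with a zero `max_attempts`. A minimal sketch, assuming `RetryConfig{MaxAttempts: 0}` is constructible as shown and that `doRequestWithRetry` tolerates a nil `*CircuitBreaker`:

```go
package kvcache

import (
    "context"
    "net/http"
    "net/http/httptest"
    "sync/atomic"
    "testing"

    "github.com/go-resty/resty/v2"
)

// Verifies that max_attempts == 0 still results in at least one HTTP call.
func TestPinExecutesAtLeastOneAttempt(t *testing.T) {
    var calls int32
    srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) {
        atomic.AddInt32(&calls, 1)
        w.Header().Set("Content-Type", "application/json")
        _, _ = w.Write([]byte(`{"event_id":"e1","num_tokens":3}`))
    }))
    defer srv.Close()

    client := NewLMCacheClient(srv.URL, resty.New(), RetryConfig{MaxAttempts: 0}, nil)
    if _, err := client.Pin(context.Background(), &PinRequest{Tokens: []int{1, 2, 3}}); err != nil {
        t.Fatalf("Pin returned error: %v", err)
    }
    if atomic.LoadInt32(&calls) == 0 {
        t.Fatal("expected at least one HTTP attempt when max_attempts is 0")
    }
}
```

Before the fix this test fails (zero attempts, nil error); after it, it passes.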
##########
pkg/filter/ai/kvcache/strategy.go:
##########
@@ -0,0 +1,113 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package kvcache
+
+import (
+    "context"
+)
+
+type CacheStrategy struct {
+    config        CacheStrategyConfig
+    loadMonitor   *LoadMonitor
+    lmcacheClient *LMCacheClient
+    tokenManager  *TokenManager
+}
+
+type StrategyDecision struct {
+    ShouldCompress bool
+    ShouldPin      bool
+    ShouldEvict    bool
+    Reason         string
+}
+
+func NewCacheStrategy(cfg CacheStrategyConfig, client *LMCacheClient, tokenManager *TokenManager) *CacheStrategy {
+    return &CacheStrategy{
+        config:        cfg,
+        loadMonitor:   NewLoadMonitor(),
+        lmcacheClient: client,
+        tokenManager:  tokenManager,
+    }
+}
+
+func (cs *CacheStrategy) RecordRequest() {
+    if cs == nil || cs.loadMonitor == nil {
+        return
+    }
+    cs.loadMonitor.RecordRequest()
+}
+
+func (cs *CacheStrategy) MakeDecision(_ context.Context, cacheStatus *LookupResponse, model string, prompt string) *StrategyDecision {
+    if cs == nil {
+        return &StrategyDecision{}
+    }
+    decision := &StrategyDecision{}
+    metrics := cs.loadMonitor.Snapshot()
+
+    if cs.config.EnableCompression && cs.config.LoadThreshold > 0 &&
+        (metrics.CPUUsage >= cs.config.LoadThreshold || metrics.RequestRate >= cs.config.LoadThreshold) {

Review Comment:
`LoadThreshold` is validated as a 0..1 ratio, but `LoadMonitor.Snapshot()` sets `RequestRate` to requests per second. Comparing `metrics.RequestRate >= LoadThreshold` will therefore trip for almost any non-trivial traffic (e.g., a threshold of 0.7 means 0.7 req/s). Either change the `LoadThreshold` semantics and validation to use req/s, or normalize `RequestRate` into 0..1 before comparing.
```suggestion
        metrics.CPUUsage >= cs.config.LoadThreshold {
```
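If the req/s signal is kept instead, it needs to be mapped into the 0..1 range before the comparison. A minimal sketch, assuming a new, hypothetical `MaxRequestRate` config field that defines what "fully loaded" means in req/s:

```go
// normalizedRate maps an observed request rate (req/s) into 0..1 against a
// configured ceiling. MaxRequestRate is a hypothetical config field here.
func normalizedRate(requestRate, maxRequestRate float64) float64 {
    if maxRequestRate <= 0 {
        return 0 // unset ceiling: rate never contributes to the load decision
    }
    if requestRate >= maxRequestRate {
        return 1
    }
    return requestRate / maxRequestRate
}
```

The condition would then compare `normalizedRate(metrics.RequestRate, cs.config.MaxRequestRate) >= cs.config.LoadThreshold`, keeping both operands on the same 0..1 scale.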
##########
pkg/filter/ai/kvcache/filter.go:
##########
@@ -0,0 +1,142 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package kvcache
+
+import (
+    "context"
+    "net/http"
+)
+
+import (
+    "github.com/go-resty/resty/v2"
+)
+
+import (
+    "github.com/apache/dubbo-go-pixiu/pkg/common/constant"
+    "github.com/apache/dubbo-go-pixiu/pkg/common/extension/filter"
+    contexthttp "github.com/apache/dubbo-go-pixiu/pkg/context/http"
+    "github.com/apache/dubbo-go-pixiu/pkg/logger"
+)
+
+const (
+    Kind = constant.AIKVCacheFilter
+)
+
+func init() {
+    filter.RegisterHttpFilter(&Plugin{})
+}
+
+type (
+    Plugin struct{}
+
+    FilterFactory struct {
+        cfg           *Config
+        httpClient    *http.Client
+        resty         *resty.Client
+        tokenManager  *TokenManager
+        lmcacheClient *LMCacheClient
+        cacheStrategy *CacheStrategy
+    }
+
+    Filter struct {
+        cfg           *Config
+        tokenManager  *TokenManager
+        lmcacheClient *LMCacheClient
+        cacheStrategy *CacheStrategy
+    }
+)
+
+func (p *Plugin) Kind() string { return Kind }
+
+func (p *Plugin) CreateFilterFactory() (filter.HttpFilterFactory, error) {
+    return &FilterFactory{cfg: &Config{}}, nil
+}
+
+func (factory *FilterFactory) Config() any { return factory.cfg }
+
+func (factory *FilterFactory) Apply() error {
+    factory.cfg.ApplyDefaults()
+    if err := factory.cfg.Validate(); err != nil {
+        return err
+    }
+    cfg := factory.cfg
+    factory.httpClient = &http.Client{
+        Timeout: cfg.RequestTimeout,
+        Transport: &http.Transport{
+            MaxIdleConns:        cfg.MaxIdleConns,
+            MaxIdleConnsPerHost: cfg.MaxIdleConnsPerHost,
+            MaxConnsPerHost:     cfg.MaxConnsPerHost,
+        },
+    }
+    factory.resty = resty.NewWithClient(factory.httpClient).
+        SetTimeout(cfg.RequestTimeout)
+
+    cbToken := NewCircuitBreaker(cfg.CircuitBreaker)
+    cbLMCache := NewCircuitBreaker(cfg.CircuitBreaker)
+    factory.tokenManager = NewTokenManager(cfg.VLLMEndpoint, factory.resty, cfg.TokenCache, cbToken, cfg.HotWindow, cfg.HotMaxRecords)
+    factory.lmcacheClient = NewLMCacheClient(cfg.LMCacheEndpoint, factory.resty, cfg.Retry, cbLMCache)
+    factory.cacheStrategy = NewCacheStrategy(cfg.CacheStrategy, factory.lmcacheClient, factory.tokenManager)
+    return nil
+}
+
+func (factory *FilterFactory) PrepareFilterChain(_ *contexthttp.HttpContext, chain filter.FilterChain) error {
+    f := &Filter{
+        cfg:           factory.cfg,
+        tokenManager:  factory.tokenManager,
+        lmcacheClient: factory.lmcacheClient,
+        cacheStrategy: factory.cacheStrategy,
+    }
+    chain.AppendDecodeFilters(f)
+    return nil
+}
+
+func (f *Filter) Decode(hc *contexthttp.HttpContext) filter.FilterStatus {
+    if f.cfg == nil || !f.cfg.Enabled {
+        return filter.Continue
+    }
+    if f.cacheStrategy != nil {

Review Comment:
This PR adds a new filter with non-trivial routing and asynchronous cache-management behavior, but there are no unit tests covering it. Adding focused tests (prompt/model extraction, preferred-endpoint routing override, LMCache retry edge cases) would help prevent regressions.
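As a cheap starting point for that coverage, a table-driven test for `extractPromptAndModel` (defined in handlers.go below) exercises the parsing paths shown in this PR. A sketch, assuming the behavior visible in the hunk (message `content` fields joined with newlines, result trimmed):

```go
package kvcache

import "testing"

func TestExtractPromptAndModel(t *testing.T) {
    cases := []struct {
        name, body, wantPrompt, wantModel string
    }{
        {"prompt field", `{"model":"m1","prompt":"hi"}`, "hi", "m1"},
        {"messages field", `{"model":"m2","messages":[{"role":"user","content":"a"},{"role":"user","content":"b"}]}`, "a\nb", "m2"},
        {"empty body", ``, "", ""},
    }
    for _, tc := range cases {
        t.Run(tc.name, func(t *testing.T) {
            prompt, model, err := extractPromptAndModel([]byte(tc.body))
            if err != nil {
                t.Fatalf("unexpected error: %v", err)
            }
            if prompt != tc.wantPrompt || model != tc.wantModel {
                t.Fatalf("got (%q, %q), want (%q, %q)", prompt, model, tc.wantPrompt, tc.wantModel)
            }
        })
    }
}
```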
##########
pkg/filter/ai/kvcache/filter.go:
##########
@@ -0,0 +1,142 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package kvcache
+
+import (
+    "context"
+    "net/http"
+)
+
+import (
+    "github.com/go-resty/resty/v2"
+)
+
+import (
+    "github.com/apache/dubbo-go-pixiu/pkg/common/constant"
+    "github.com/apache/dubbo-go-pixiu/pkg/common/extension/filter"
+    contexthttp "github.com/apache/dubbo-go-pixiu/pkg/context/http"
+    "github.com/apache/dubbo-go-pixiu/pkg/logger"
+)
+
+const (
+    Kind = constant.AIKVCacheFilter
+)
+
+func init() {
+    filter.RegisterHttpFilter(&Plugin{})
+}
+
+type (
+    Plugin struct{}
+
+    FilterFactory struct {
+        cfg           *Config
+        httpClient    *http.Client
+        resty         *resty.Client
+        tokenManager  *TokenManager
+        lmcacheClient *LMCacheClient
+        cacheStrategy *CacheStrategy
+    }
+
+    Filter struct {
+        cfg           *Config
+        tokenManager  *TokenManager
+        lmcacheClient *LMCacheClient
+        cacheStrategy *CacheStrategy
+    }
+)
+
+func (p *Plugin) Kind() string { return Kind }
+
+func (p *Plugin) CreateFilterFactory() (filter.HttpFilterFactory, error) {
+    return &FilterFactory{cfg: &Config{}}, nil
+}
+
+func (factory *FilterFactory) Config() any { return factory.cfg }
+
+func (factory *FilterFactory) Apply() error {
+    factory.cfg.ApplyDefaults()
+    if err := factory.cfg.Validate(); err != nil {
+        return err
+    }
+    cfg := factory.cfg
+    factory.httpClient = &http.Client{
+        Timeout: cfg.RequestTimeout,
+        Transport: &http.Transport{
+            MaxIdleConns:        cfg.MaxIdleConns,
+            MaxIdleConnsPerHost: cfg.MaxIdleConnsPerHost,
+            MaxConnsPerHost:     cfg.MaxConnsPerHost,
+        },
+    }
+    factory.resty = resty.NewWithClient(factory.httpClient).
+        SetTimeout(cfg.RequestTimeout)
+
+    cbToken := NewCircuitBreaker(cfg.CircuitBreaker)
+    cbLMCache := NewCircuitBreaker(cfg.CircuitBreaker)
+    factory.tokenManager = NewTokenManager(cfg.VLLMEndpoint, factory.resty, cfg.TokenCache, cbToken, cfg.HotWindow, cfg.HotMaxRecords)
+    factory.lmcacheClient = NewLMCacheClient(cfg.LMCacheEndpoint, factory.resty, cfg.Retry, cbLMCache)
+    factory.cacheStrategy = NewCacheStrategy(cfg.CacheStrategy, factory.lmcacheClient, factory.tokenManager)
+    return nil
+}
+
+func (factory *FilterFactory) PrepareFilterChain(_ *contexthttp.HttpContext, chain filter.FilterChain) error {
+    f := &Filter{
+        cfg:           factory.cfg,
+        tokenManager:  factory.tokenManager,
+        lmcacheClient: factory.lmcacheClient,
+        cacheStrategy: factory.cacheStrategy,
+    }
+    chain.AppendDecodeFilters(f)
+    return nil
+}
+
+func (f *Filter) Decode(hc *contexthttp.HttpContext) filter.FilterStatus {
+    if f.cfg == nil || !f.cfg.Enabled {
+        return filter.Continue
+    }
+    if f.cacheStrategy != nil {
+        f.cacheStrategy.RecordRequest()
+    }
+    body, err := readRequestBody(hc.Request)
+    if err != nil {
+        logger.Warnf("[KVCache] read request body failed: %v", err)
+        return filter.Continue
+    }
+    prompt, model, err := extractPromptAndModel(body)
+    if err != nil {
+        logger.Warnf("[KVCache] parse request body failed: %v", err)
+        return filter.Continue
+    }
+    if prompt == "" {
+        return filter.Continue
+    }
+    if model == "" {
+        model = f.cfg.DefaultModel
+    }
+
+    f.tokenManager.RecordHot(model, prompt)
+
+    cacheStatus, routed := f.tryRouteToCachedInstance(hc, model, prompt)
+
+    ctx, cancel := context.WithTimeout(hc.Ctx, effectiveTimeout(hc, f.cfg))
+    go func() {
+        defer cancel()
+        f.manageCache(ctx, model, prompt, body, cacheStatus, routed)
+    }()

Review Comment:
`context.WithTimeout(hc.Ctx, ...)` uses `hc.Ctx`, which is set to `context.Background()` in the HTTP connection manager, so these background cache-management goroutines won't be canceled when the client request is canceled. Consider deriving this context from `hc.Request.Context()` (or ensuring `hc.Ctx` is set from it) so work is aborted promptly and goroutine buildup under load is avoided.
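A minimal sketch of the suggested derivation (assuming `hc.Request` is the standard `*http.Request`, as the `readRequestBody(hc.Request)` call in handlers.go implies):

```go
import (
    "context"
    "time"

    contexthttp "github.com/apache/dubbo-go-pixiu/pkg/context/http"
)

// requestScopedContext prefers the client request's context so background
// cache-management work is canceled when the client disconnects.
func requestScopedContext(hc *contexthttp.HttpContext, timeout time.Duration) (context.Context, context.CancelFunc) {
    parent := context.Background()
    if hc != nil && hc.Request != nil {
        parent = hc.Request.Context() // canceled on client disconnect
    }
    return context.WithTimeout(parent, timeout)
}
```

The call site would then read `ctx, cancel := requestScopedContext(hc, effectiveTimeout(hc, f.cfg))`.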
##########
pkg/filter/ai/kvcache/token_manager.go:
##########
@@ -0,0 +1,305 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package kvcache
+
+import (
+    "context"
+    "encoding/json"
+    "fmt"
+    "strings"
+    "sync"
+    "sync/atomic"
+    "time"
+)
+
+import (
+    "github.com/go-resty/resty/v2"
+)
+
+type TokenManager struct {
+    httpClient     *resty.Client
+    endpoint       string
+    cache          sync.Map
+    config         TokenCacheConfig
+    circuitBreaker *CircuitBreaker
+
+    cacheSize int64
+    hitCount  int64
+    missCount int64
+
+    hotWindow time.Duration
+    hotMax    int
+    hotMu     sync.Mutex
+    hotMap    map[string][]time.Time
+}
+
+type TokenizeRequest struct {
+    Model  string `json:"model,omitempty"`
+    Prompt string `json:"prompt"`
+}
+
+type TokenizeResponse struct {
+    Count  int   `json:"count"`
+    Tokens []int `json:"tokens"`
+    MaxLen int   `json:"max_model_len"`
+}
+
+type tokenCacheEntry struct {
+    tokens    []int
+    expiresAt time.Time
+}
+
+func NewTokenManager(endpoint string, httpClient *resty.Client, cfg TokenCacheConfig, cb *CircuitBreaker, hotWindow time.Duration, hotMax int) *TokenManager {
+    return &TokenManager{
+        httpClient:     httpClient,
+        endpoint:       endpoint,
+        config:         cfg,
+        circuitBreaker: cb,
+        hotWindow:      hotWindow,
+        hotMax:         hotMax,
+        hotMap:         make(map[string][]time.Time),
+    }
+}
+
+func (tm *TokenManager) GetTokens(ctx context.Context, model string, prompt string, rawBody []byte) ([]int, error) {
+    cacheKey := tm.cacheKey(model, prompt)
+    if tm.config.Enabled {
+        if tokens, ok := tm.loadCache(cacheKey); ok {
+            atomic.AddInt64(&tm.hitCount, 1)
+            return tokens, nil
+        }
+        atomic.AddInt64(&tm.missCount, 1)
+    }
+
+    tokens, err := tm.tokenize(ctx, model, prompt, rawBody)
+    if err != nil {
+        return nil, err
+    }
+
+    if tm.config.Enabled {
+        tm.storeCache(cacheKey, tokens)
+    }
+    return tokens, nil
+}
+
+func (tm *TokenManager) GetCachedTokens(model string, prompt string) ([]int, bool) {
+    if !tm.config.Enabled {
+        return nil, false
+    }
+    cacheKey := tm.cacheKey(model, prompt)
+    tokens, ok := tm.loadCache(cacheKey)
+    if ok {
+        atomic.AddInt64(&tm.hitCount, 1)
+    } else {
+        atomic.AddInt64(&tm.missCount, 1)
+    }
+    return tokens, ok
+}
+
+func (tm *TokenManager) InvalidateCache(model string, prompt string) {
+    cacheKey := tm.cacheKey(model, prompt)
+    if _, ok := tm.cache.Load(cacheKey); ok {
+        tm.cache.Delete(cacheKey)
+        atomic.AddInt64(&tm.cacheSize, -1)
+    }
+}
+
+func (tm *TokenManager) GetCacheStats() CacheStats {
+    size := atomic.LoadInt64(&tm.cacheSize)
+    hit := atomic.LoadInt64(&tm.hitCount)
+    miss := atomic.LoadInt64(&tm.missCount)
+    total := hit + miss
+    var hitRate float64
+    if total > 0 {
+        hitRate = float64(hit) / float64(total)
+    }
+    return CacheStats{
+        Size:      int(size),
+        HitRate:   hitRate,
+        HitCount:  hit,
+        MissCount: miss,
+    }
+}
+
+func (tm *TokenManager) tokenize(ctx context.Context, model string, prompt string, rawBody []byte) ([]int, error) {
+    var tokens []int
+    err := tm.execute(ctx, func() error {
+        body, err := tm.buildTokenizeBody(model, prompt, rawBody)
+        if err != nil {
+            return err
+        }
+        resp, err := tm.doTokenizeRequest(ctx, body)
+        if err != nil {
+            return err
+        }
+        tokens = resp.Tokens
+        return nil
+    })
+    if err != nil {
+        return nil, err
+    }
+    return tokens, nil
+}
+
+func (tm *TokenManager) buildTokenizeBody(model string, prompt string, rawBody []byte) (any, error) {
+    if len(rawBody) > 0 {
+        return rawBody, nil
+    }
+    return TokenizeRequest{Model: model, Prompt: prompt}, nil
+}
+
+func (tm *TokenManager) doTokenizeRequest(ctx context.Context, body any) (*TokenizeResponse, error) {
+    tokenizeURL := strings.TrimRight(tm.endpoint, "/") + "/tokenize"
+    resp, err := tm.httpClient.R().
+        SetContext(ctx).
+        SetHeader("Content-Type", "application/json").
+        SetBody(body).
+        Post(tokenizeURL)
+    if err != nil {
+        return nil, fmt.Errorf("call tokenize: %w", err)
+    }
+    if resp.StatusCode() < 200 || resp.StatusCode() >= 300 {
+        return nil, fmt.Errorf("tokenize status %d: %s", resp.StatusCode(), strings.TrimSpace(string(resp.Body())))
+    }
+    var tokenResp TokenizeResponse
+    if err := json.Unmarshal(resp.Body(), &tokenResp); err != nil {
+        return nil, fmt.Errorf("decode tokenize response: %w", err)
+    }
+    return &tokenResp, nil
+}
+
+func (tm *TokenManager) RecordHot(model string, prompt string) {
+    if tm == nil || tm.hotWindow <= 0 || model == "" || prompt == "" {
+        return
+    }
+    now := time.Now()
+    key := tm.cacheKey(model, prompt)
+    tm.hotMu.Lock()
+    defer tm.hotMu.Unlock()
+    entries := tm.hotMap[key]
+    entries = append(entries, now)
+    entries = trimHotWindow(entries, now, tm.hotWindow)
+    if tm.hotMax > 0 && len(entries) > tm.hotMax {
+        entries = entries[len(entries)-tm.hotMax:]
+    }
+    if len(entries) == 0 {
+        delete(tm.hotMap, key)
+        return
+    }
+    tm.hotMap[key] = entries
+}
+
+func (tm *TokenManager) IsHot(model string, prompt string, threshold int) bool {
+    if tm == nil || tm.hotWindow <= 0 || threshold <= 0 || model == "" || prompt == "" {
+        return false
+    }
+    now := time.Now()
+    key := tm.cacheKey(model, prompt)
+    tm.hotMu.Lock()
+    defer tm.hotMu.Unlock()
+    entries := tm.hotMap[key]
+    if len(entries) == 0 {
+        return false
+    }
+    entries = trimHotWindow(entries, now, tm.hotWindow)
+    if tm.hotMax > 0 && len(entries) > tm.hotMax {
+        entries = entries[len(entries)-tm.hotMax:]
+    }
+    if len(entries) == 0 {
+        delete(tm.hotMap, key)
+        return false
+    }
+    tm.hotMap[key] = entries
+    return len(entries) >= threshold
+}
+
+func trimHotWindow(entries []time.Time, now time.Time, window time.Duration) []time.Time {
+    if window <= 0 || len(entries) == 0 {
+        return entries
+    }
+    cutoff := now.Add(-window)
+    idx := 0
+    for idx < len(entries) && entries[idx].Before(cutoff) {
+        idx++
+    }
+    if idx == 0 {
+        return entries
+    }
+    return append([]time.Time(nil), entries[idx:]...)
+}
+
+func (tm *TokenManager) execute(ctx context.Context, operation func() error) error {
+    if tm.circuitBreaker == nil {
+        return operation()
+    }
+    return tm.circuitBreaker.Execute(operation)
+}
+
+func (tm *TokenManager) cacheKey(model string, prompt string) string {
+    if model == "" {
+        return prompt
+    }
+    return model + ":" + prompt

Review Comment:
`cacheKey` concatenates the full `prompt` into the in-memory cache key. For large prompts this can significantly increase memory usage (duplicate prompt strings across `cache` and `hotMap`) and GC pressure. Consider using a bounded-length key (e.g., a hash of model+prompt) while keeping collision risk acceptable.
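A bounded-key sketch of that suggestion (the hashing approach is the reviewer's proposal, not code from this PR):

```go
import (
    "crypto/sha256"
    "encoding/hex"
)

// hashedCacheKey derives a fixed-length key from model and prompt. The zero
// byte separates the parts so ("ab","c") and ("a","bc") cannot collide.
func hashedCacheKey(model, prompt string) string {
    h := sha256.New()
    h.Write([]byte(model))
    h.Write([]byte{0})
    h.Write([]byte(prompt))
    return hex.EncodeToString(h.Sum(nil))
}
```

This keeps every `cache` and `hotMap` key at 64 hex characters regardless of prompt size, at a negligible SHA-256 collision risk.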
##########
pkg/filter/ai/kvcache/handlers.go:
##########
@@ -0,0 +1,183 @@
+package kvcache
+
+import (
+    "bytes"
+    "context"

Review Comment:
This file is missing the standard Apache Software Foundation license header that appears at the top of the other Go files in this repo. Please add the ASF header to keep licensing consistent.

##########
pkg/filter/ai/kvcache/load_monitor.go:
##########
@@ -0,0 +1,74 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package kvcache
+
+import (
+    "runtime"
+    "sync"
+    "time"
+)
+
+type LoadMonitor struct {
+    window time.Duration
+    last   time.Time
+    count  int64
+    rate   float64
+    mutex  sync.Mutex
+}
+
+func NewLoadMonitor() *LoadMonitor {
+    return &LoadMonitor{
+        window: time.Second,
+        last:   time.Now(),
+    }
+}
+
+func (lm *LoadMonitor) RecordRequest() {
+    if lm == nil {
+        return
+    }
+    lm.mutex.Lock()
+    lm.count++
+    lm.mutex.Unlock()
+}
+
+func (lm *LoadMonitor) Snapshot() LoadMetrics {
+    if lm == nil {
+        return LoadMetrics{}
+    }
+    lm.mutex.Lock()
+    defer lm.mutex.Unlock()
+    now := time.Now()
+    elapsed := now.Sub(lm.last)
+    if elapsed >= lm.window && elapsed > 0 {
+        lm.rate = float64(lm.count) / elapsed.Seconds()
+        lm.count = 0
+        lm.last = now
+    }
+    var ms runtime.MemStats
+    runtime.ReadMemStats(&ms)
+    memUsage := 0.0
+    if ms.Sys > 0 {
+        memUsage = float64(ms.Alloc) / float64(ms.Sys)
+    }
+    return LoadMetrics{
+        CPUUsage:    0,
+        MemoryUsage: memUsage,
+        RequestRate: lm.rate,
+    }

Review Comment:
`Snapshot()` always returns `CPUUsage: 0`, but the strategy checks CPU usage against `LoadThreshold`. This makes CPU-based decisions impossible and can mislead operators configuring `load_threshold`. Either implement CPU utilization collection or remove CPU from the decision criteria and config.
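If CPU-based decisions are kept, utilization has to be sampled from the OS. One hedged option uses the third-party gopsutil library (an assumption; it is not currently a dependency of this module):

```go
import "github.com/shirou/gopsutil/v3/cpu"

// sampleCPUUsage returns system CPU utilization in 0..1, matching the
// LoadThreshold scale; it returns 0 if sampling fails.
func sampleCPUUsage() float64 {
    // An interval of 0 measures utilization since the previous call.
    percents, err := cpu.Percent(0, false)
    if err != nil || len(percents) == 0 {
        return 0
    }
    return percents[0] / 100.0
}
```

`Snapshot()` could then fill `CPUUsage` from this sample instead of the hardcoded 0.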
##########
pkg/filter/ai/kvcache/handlers.go:
##########
@@ -0,0 +1,183 @@
+package kvcache
+
+import (
+    "bytes"
+    "context"
+    "encoding/json"
+    "io"
+    "net/http"
+    "strings"
+    "time"
+)
+
+import (
+    "github.com/apache/dubbo-go-pixiu/pkg/common/constant"
+    contexthttp "github.com/apache/dubbo-go-pixiu/pkg/context/http"
+    "github.com/apache/dubbo-go-pixiu/pkg/logger"
+)
+
+func (f *Filter) manageCache(ctx context.Context, model string, prompt string, rawBody []byte, cacheStatus *LookupResponse, lookupDone bool) {
+    if ctx.Err() != nil {
+        return
+    }
+    tokens, err := f.tokenManager.GetTokens(ctx, model, prompt, rawBody)
+    if err != nil {
+        logger.Warnf("[KVCache] tokenize failed: %v", err)
+        return
+    }
+    if ctx.Err() != nil {
+        return
+    }
+    if !lookupDone || cacheStatus == nil {
+        cacheStatus, err = f.lmcacheClient.Lookup(ctx, &LookupRequest{Tokens: tokens})
+        if err != nil {
+            logger.Warnf("[KVCache] lookup failed: %v", err)
+            return
+        }
+    }
+    decision := f.cacheStrategy.MakeDecision(ctx, cacheStatus, model, prompt)
+    if ctx.Err() != nil {
+        return
+    }
+    if err := f.cacheStrategy.ExecuteDecision(ctx, decision, tokens); err != nil {
+        logger.Warnf("[KVCache] execute strategy failed: %v", err)
+    }
+}
+
+func readRequestBody(req *http.Request) ([]byte, error) {
+    if req == nil || req.Body == nil {
+        return nil, nil
+    }
+    if req.GetBody != nil {
+        reader, err := req.GetBody()
+        if err == nil {
+            defer reader.Close()
+            return io.ReadAll(reader)
+        }
+    }
+    bodyBytes, err := io.ReadAll(req.Body)
+    if err != nil {
+        return nil, err
+    }
+    req.Body.Close()
+    req.Body = io.NopCloser(bytes.NewReader(bodyBytes))
+    req.GetBody = func() (io.ReadCloser, error) {
+        return io.NopCloser(bytes.NewReader(bodyBytes)), nil
+    }
+    return bodyBytes, nil
+}
+
+func extractPromptAndModel(body []byte) (string, string, error) {
+    if len(body) == 0 {
+        return "", "", nil
+    }
+    var payload map[string]any
+    if err := json.Unmarshal(body, &payload); err != nil {
+        return "", "", err
+    }
+    model, _ := payload["model"].(string)
+    prompt := coercePrompt(payload["prompt"])
+    if prompt == "" {
+        prompt = extractPromptFromMessages(payload["messages"])
+    }
+    return strings.TrimSpace(prompt), model, nil
+}
+
+func coercePrompt(value any) string {
+    switch v := value.(type) {
+    case string:
+        return v
+    case []any:
+        parts := make([]string, 0, len(v))
+        for _, item := range v {
+            if str, ok := item.(string); ok {
+                parts = append(parts, str)
+            }
+        }
+        return strings.Join(parts, "\n")
+    default:
+        return ""
+    }
+}
+
+func extractPromptFromMessages(value any) string {
+    msgs, ok := value.([]any)
+    if !ok {
+        return ""
+    }
+    parts := make([]string, 0, len(msgs))
+    for _, msg := range msgs {
+        msgMap, ok := msg.(map[string]any)
+        if !ok {
+            continue
+        }
+        if content, ok := msgMap["content"].(string); ok {
+            parts = append(parts, content)
+        }
+    }
+    return strings.Join(parts, "\n")
+}
+
+func selectPreferredInstanceID(resp *LookupResponse) string {
+    if resp == nil || len(resp.LayoutInfo) == 0 {
+        return ""
+    }
+    var (
+        selected string
+        maxCount int
+    )
+    for instanceID, layout := range resp.LayoutInfo {
+        if layout.TokenCount > maxCount || selected == "" {
+            selected = instanceID
+            maxCount = layout.TokenCount
+        }
+    }
+    return selected
+}
+
+func effectiveTimeout(hc *contexthttp.HttpContext, cfg *Config) time.Duration {
+    if cfg == nil {
+        return 0
+    }
+    timeout := cfg.RequestTimeout
+    if hc != nil && hc.Timeout > 0 && (timeout <= 0 || hc.Timeout < timeout) {
+        timeout = hc.Timeout
+    }
+    if timeout <= 0 {
+        return 2 * time.Second
+    }
+    return timeout
+}
+
+func (f *Filter) tryRouteToCachedInstance(hc *contexthttp.HttpContext, model string, prompt string) (*LookupResponse, bool) {
+    if f == nil || f.tokenManager == nil || f.lmcacheClient == nil {
+        return nil, false
+    }
+    tokens, ok := f.tokenManager.GetCachedTokens(model, prompt)
+    if !ok || len(tokens) == 0 {
+        logger.Debugf("[KVCache] routing lookup skipped: token cache miss")
+        return nil, false
+    }
+    timeout := effectiveTimeout(hc, f.cfg)
+    if f.cfg != nil && f.cfg.LookupRoutingTimeout > 0 && f.cfg.LookupRoutingTimeout < timeout {
+        timeout = f.cfg.LookupRoutingTimeout
+    }
+    ctx, cancel := context.WithTimeout(hc.Ctx, timeout)

Review Comment:
`context.WithTimeout(hc.Ctx, ...)` uses `hc.Ctx`, which is `context.Background()` in the HTTP manager, so lookup work won't be canceled when the client disconnects. Using `hc.Request.Context()` here (or setting `hc.Ctx` from it) would allow early cancellation and reduce unnecessary LMCache calls; the same `requestScopedContext` sketch shown earlier would apply here too.
##########
pkg/filter/ai/kvcache/token_manager.go:
##########
@@ -0,0 +1,305 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package kvcache
+
+import (
+    "context"
+    "encoding/json"
+    "fmt"
+    "strings"
+    "sync"
+    "sync/atomic"
+    "time"
+)
+
+import (
+    "github.com/go-resty/resty/v2"
+)
+
+type TokenManager struct {
+    httpClient     *resty.Client
+    endpoint       string
+    cache          sync.Map
+    config         TokenCacheConfig
+    circuitBreaker *CircuitBreaker
+
+    cacheSize int64
+    hitCount  int64
+    missCount int64
+
+    hotWindow time.Duration
+    hotMax    int
+    hotMu     sync.Mutex
+    hotMap    map[string][]time.Time
+}
+
+type TokenizeRequest struct {
+    Model  string `json:"model,omitempty"`
+    Prompt string `json:"prompt"`
+}
+
+type TokenizeResponse struct {
+    Count  int   `json:"count"`
+    Tokens []int `json:"tokens"`
+    MaxLen int   `json:"max_model_len"`
+}
+
+type tokenCacheEntry struct {
+    tokens    []int
+    expiresAt time.Time
+}
+
+func NewTokenManager(endpoint string, httpClient *resty.Client, cfg TokenCacheConfig, cb *CircuitBreaker, hotWindow time.Duration, hotMax int) *TokenManager {
+    return &TokenManager{
+        httpClient:     httpClient,
+        endpoint:       endpoint,
+        config:         cfg,
+        circuitBreaker: cb,
+        hotWindow:      hotWindow,
+        hotMax:         hotMax,
+        hotMap:         make(map[string][]time.Time),
+    }
+}
+
+func (tm *TokenManager) GetTokens(ctx context.Context, model string, prompt string, rawBody []byte) ([]int, error) {
+    cacheKey := tm.cacheKey(model, prompt)
+    if tm.config.Enabled {
+        if tokens, ok := tm.loadCache(cacheKey); ok {
+            atomic.AddInt64(&tm.hitCount, 1)
+            return tokens, nil
+        }
+        atomic.AddInt64(&tm.missCount, 1)
+    }
+
+    tokens, err := tm.tokenize(ctx, model, prompt, rawBody)
+    if err != nil {
+        return nil, err
+    }
+
+    if tm.config.Enabled {
+        tm.storeCache(cacheKey, tokens)
+    }
+    return tokens, nil
+}
+
+func (tm *TokenManager) GetCachedTokens(model string, prompt string) ([]int, bool) {
+    if !tm.config.Enabled {
+        return nil, false
+    }
+    cacheKey := tm.cacheKey(model, prompt)
+    tokens, ok := tm.loadCache(cacheKey)
+    if ok {
+        atomic.AddInt64(&tm.hitCount, 1)
+    } else {
+        atomic.AddInt64(&tm.missCount, 1)
+    }
+    return tokens, ok
+}
+
+func (tm *TokenManager) InvalidateCache(model string, prompt string) {
+    cacheKey := tm.cacheKey(model, prompt)
+    if _, ok := tm.cache.Load(cacheKey); ok {
+        tm.cache.Delete(cacheKey)
+        atomic.AddInt64(&tm.cacheSize, -1)
+    }
+}
+
+func (tm *TokenManager) GetCacheStats() CacheStats {
+    size := atomic.LoadInt64(&tm.cacheSize)
+    hit := atomic.LoadInt64(&tm.hitCount)
+    miss := atomic.LoadInt64(&tm.missCount)
+    total := hit + miss
+    var hitRate float64
+    if total > 0 {
+        hitRate = float64(hit) / float64(total)
+    }
+    return CacheStats{
+        Size:      int(size),
+        HitRate:   hitRate,
+        HitCount:  hit,
+        MissCount: miss,
+    }
+}
+
+func (tm *TokenManager) tokenize(ctx context.Context, model string, prompt string, rawBody []byte) ([]int, error) {
+    var tokens []int
+    err := tm.execute(ctx, func() error {
+        body, err := tm.buildTokenizeBody(model, prompt, rawBody)
+        if err != nil {
+            return err
+        }
+        resp, err := tm.doTokenizeRequest(ctx, body)
+        if err != nil {
+            return err
+        }
+        tokens = resp.Tokens
+        return nil
+    })
+    if err != nil {
+        return nil, err
+    }
+    return tokens, nil
+}
+
+func (tm *TokenManager) buildTokenizeBody(model string, prompt string, rawBody []byte) (any, error) {
+    if len(rawBody) > 0 {
+        return rawBody, nil
+    }
+    return TokenizeRequest{Model: model, Prompt: prompt}, nil
+}
+
+func (tm *TokenManager) doTokenizeRequest(ctx context.Context, body any) (*TokenizeResponse, error) {
+    tokenizeURL := strings.TrimRight(tm.endpoint, "/") + "/tokenize"
+    resp, err := tm.httpClient.R().
+        SetContext(ctx).
+        SetHeader("Content-Type", "application/json").
+        SetBody(body).
+        Post(tokenizeURL)
+    if err != nil {
+        return nil, fmt.Errorf("call tokenize: %w", err)
+    }
+    if resp.StatusCode() < 200 || resp.StatusCode() >= 300 {
+        return nil, fmt.Errorf("tokenize status %d: %s", resp.StatusCode(), strings.TrimSpace(string(resp.Body())))
+    }
+    var tokenResp TokenizeResponse
+    if err := json.Unmarshal(resp.Body(), &tokenResp); err != nil {
+        return nil, fmt.Errorf("decode tokenize response: %w", err)
+    }
+    return &tokenResp, nil
+}
+
+func (tm *TokenManager) RecordHot(model string, prompt string) {
+    if tm == nil || tm.hotWindow <= 0 || model == "" || prompt == "" {
+        return
+    }
+    now := time.Now()
+    key := tm.cacheKey(model, prompt)
+    tm.hotMu.Lock()
+    defer tm.hotMu.Unlock()
+    entries := tm.hotMap[key]
+    entries = append(entries, now)
+    entries = trimHotWindow(entries, now, tm.hotWindow)
+    if tm.hotMax > 0 && len(entries) > tm.hotMax {
+        entries = entries[len(entries)-tm.hotMax:]
+    }
+    if len(entries) == 0 {
+        delete(tm.hotMap, key)
+        return
+    }
+    tm.hotMap[key] = entries
+}
+
+func (tm *TokenManager) IsHot(model string, prompt string, threshold int) bool {
+    if tm == nil || tm.hotWindow <= 0 || threshold <= 0 || model == "" || prompt == "" {
+        return false
+    }
+    now := time.Now()
+    key := tm.cacheKey(model, prompt)
+    tm.hotMu.Lock()
+    defer tm.hotMu.Unlock()
+    entries := tm.hotMap[key]
+    if len(entries) == 0 {
+        return false
+    }
+    entries = trimHotWindow(entries, now, tm.hotWindow)
+    if tm.hotMax > 0 && len(entries) > tm.hotMax {
+        entries = entries[len(entries)-tm.hotMax:]
+    }
+    if len(entries) == 0 {
+        delete(tm.hotMap, key)
+        return false
+    }
+    tm.hotMap[key] = entries
+    return len(entries) >= threshold
+}
+
+func trimHotWindow(entries []time.Time, now time.Time, window time.Duration) []time.Time {
+    if window <= 0 || len(entries) == 0 {
+        return entries
+    }
+    cutoff := now.Add(-window)
+    idx := 0
+    for idx < len(entries) && entries[idx].Before(cutoff) {
+        idx++
+    }
+    if idx == 0 {
+        return entries
+    }
+    return append([]time.Time(nil), entries[idx:]...)
+}
+
+func (tm *TokenManager) execute(ctx context.Context, operation func() error) error {
+    if tm.circuitBreaker == nil {
+        return operation()
+    }
+    return tm.circuitBreaker.Execute(operation)
+}
+
+func (tm *TokenManager) cacheKey(model string, prompt string) string {
+    if model == "" {
+        return prompt
+    }
+    return model + ":" + prompt
+}
+
+func (tm *TokenManager) loadCache(key string) ([]int, bool) {
+    entryAny, ok := tm.cache.Load(key)
+    if !ok {
+        return nil, false
+    }
+    entry, ok := entryAny.(*tokenCacheEntry)
+    if !ok {
+        tm.cache.Delete(key)
+        atomic.AddInt64(&tm.cacheSize, -1)
+        return nil, false
+    }
+    if tm.config.TTL > 0 && time.Now().After(entry.expiresAt) {
+        tm.cache.Delete(key)
+        atomic.AddInt64(&tm.cacheSize, -1)
+        return nil, false

Review Comment:
`cacheSize` is decremented on delete paths that can run concurrently (e.g., multiple goroutines can observe an entry as expired and each call `Delete` followed by `AddInt64(-1)`), which can make `cacheSize` drift or turn negative and break the max-size eviction loop. Consider tracking size with a guarded map (mutex + map), or only decrementing when the delete provably removed a present entry (e.g., `sync.Map.LoadAndDelete`, available since Go 1.15).
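A minimal sketch of the single-decrement variant: `LoadAndDelete` reports whether this caller actually removed the entry, so the counter is adjusted exactly once even when expirations race:

```go
// removeCacheEntry deletes key and decrements cacheSize exactly once, even
// when several goroutines race to expire the same entry.
func (tm *TokenManager) removeCacheEntry(key string) {
    if _, loaded := tm.cache.LoadAndDelete(key); loaded {
        atomic.AddInt64(&tm.cacheSize, -1)
    }
}
```

All three delete sites (`InvalidateCache` and both branches in `loadCache`) could route through this helper.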
--
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: [email protected]

For queries about this service, please contact Infrastructure at:
[email protected]

---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]
