This is an automated email from the ASF dual-hosted git repository.

liuhongyu pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/shenyu.git


The following commit(s) were added to refs/heads/master by this push:
     new 13c70320e7 [feat] optimize aiTokenLimiterPlugin for streaming tokens (#6055)
13c70320e7 is described below

commit 13c70320e708569a6fc95dde61f70db99c6183b1
Author: HY-love-sleep <[email protected]>
AuthorDate: Thu Jul 17 17:20:36 2025 +0800

    [feat] optimize aiTokenLimiterPlugin for streaming tokens (#6055)
    
    * fix: optimize aiTokenLimiterPlugin for streaming tokens
    
    * chore: java format
    
    * chore: code review by copilot
---
 .../apache/shenyu/common/constant/Constants.java   |  10 ++
 .../plugin/ai/common/strategy/openai/OpenAI.java   |   8 +-
 .../ai/token/limiter/AiTokenLimiterPlugin.java     | 139 ++++++++++++++-------
 3 files changed, 112 insertions(+), 45 deletions(-)

diff --git 
a/shenyu-common/src/main/java/org/apache/shenyu/common/constant/Constants.java 
b/shenyu-common/src/main/java/org/apache/shenyu/common/constant/Constants.java
index 25136e5321..f8a25531f1 100644
--- 
a/shenyu-common/src/main/java/org/apache/shenyu/common/constant/Constants.java
+++ 
b/shenyu-common/src/main/java/org/apache/shenyu/common/constant/Constants.java
@@ -994,6 +994,16 @@ public interface Constants {
      * The constant USAGE.
      */
     String USAGE = "usage";
+
+    /**
+     * The include_usage for stream.
+     */
+    String INCLUDE_USAGE = "include_usage";
+
+    /**
+     * The stream_options.
+     */
+    String STREAM_OPTIONS = "stream_options";
     
     /**
      * The constant COMPLETION_TOKENS.
diff --git 
a/shenyu-plugin/shenyu-plugin-ai/shenyu-plugin-ai-common/src/main/java/org/apache/shenyu/plugin/ai/common/strategy/openai/OpenAI.java
 
b/shenyu-plugin/shenyu-plugin-ai/shenyu-plugin-ai-common/src/main/java/org/apache/shenyu/plugin/ai/common/strategy/openai/OpenAI.java
index a81416c755..aa1a4f30e6 100644
--- 
a/shenyu-plugin/shenyu-plugin-ai/shenyu-plugin-ai-common/src/main/java/org/apache/shenyu/plugin/ai/common/strategy/openai/OpenAI.java
+++ 
b/shenyu-plugin/shenyu-plugin-ai/shenyu-plugin-ai-common/src/main/java/org/apache/shenyu/plugin/ai/common/strategy/openai/OpenAI.java
@@ -35,6 +35,7 @@ import org.springframework.http.codec.HttpMessageReader;
 import org.springframework.web.server.ServerWebExchange;
 import reactor.core.publisher.Mono;
 
+import java.util.HashMap;
 import java.util.List;
 import java.util.Map;
 import java.util.Objects;
@@ -66,7 +67,7 @@ public class OpenAI implements AiModel {
                     return Mono.error(error);
                 });
     }
-    
+
     @Override
     public Long getCompletionTokens(final String responseBody) {
         try {
@@ -101,6 +102,11 @@ public class OpenAI implements AiModel {
         Map<String, Object> requestBodyMap = 
GsonUtils.getInstance().convertToMap(originalBody);
         requestBodyMap.put(Constants.MODEL, aiCommonConfig.getModel());
         requestBodyMap.put(Constants.STREAM, aiCommonConfig.getStream());
+        if (aiCommonConfig.getStream()) {
+            Map<String, Object> streamOptions = new HashMap<>();
+            streamOptions.put(Constants.INCLUDE_USAGE, true);
+            requestBodyMap.put(Constants.STREAM_OPTIONS, streamOptions);
+        }
         return GsonUtils.getInstance().toJson(requestBodyMap);
     }
     
diff --git 
a/shenyu-plugin/shenyu-plugin-ai/shenyu-plugin-ai-token-limiter/src/main/java/org/apache/shenyu/plugin/ai/token/limiter/AiTokenLimiterPlugin.java
 
b/shenyu-plugin/shenyu-plugin-ai/shenyu-plugin-ai-token-limiter/src/main/java/org/apache/shenyu/plugin/ai/token/limiter/AiTokenLimiterPlugin.java
index c7f66d2ae9..0765ed9695 100644
--- 
a/shenyu-plugin/shenyu-plugin-ai/shenyu-plugin-ai-token-limiter/src/main/java/org/apache/shenyu/plugin/ai/token/limiter/AiTokenLimiterPlugin.java
+++ 
b/shenyu-plugin/shenyu-plugin-ai/shenyu-plugin-ai-token-limiter/src/main/java/org/apache/shenyu/plugin/ai/token/limiter/AiTokenLimiterPlugin.java
@@ -17,6 +17,8 @@
 
 package org.apache.shenyu.plugin.ai.token.limiter;
 
+import com.fasterxml.jackson.databind.JsonNode;
+import com.fasterxml.jackson.databind.ObjectMapper;
 import org.apache.commons.lang3.StringUtils;
 import org.apache.shenyu.common.constant.Constants;
 import org.apache.shenyu.common.dto.RuleData;
@@ -24,7 +26,6 @@ import org.apache.shenyu.common.dto.SelectorData;
 import org.apache.shenyu.common.dto.convert.rule.AiTokenLimiterHandle;
 import org.apache.shenyu.common.enums.AiTokenLimiterEnum;
 import org.apache.shenyu.common.enums.PluginEnum;
-import org.apache.shenyu.plugin.ai.common.strategy.AiModel;
 import 
org.apache.shenyu.plugin.ai.token.limiter.handler.AiTokenLimiterPluginHandler;
 import org.apache.shenyu.plugin.api.ShenyuPluginChain;
 import org.apache.shenyu.plugin.api.result.ShenyuResultEnum;
@@ -61,6 +62,8 @@ import java.util.Objects;
 import java.util.Optional;
 import java.util.concurrent.atomic.AtomicBoolean;
 import java.util.function.Consumer;
+import java.util.regex.Matcher;
+import java.util.regex.Pattern;
 import java.util.zip.DataFormatException;
 import java.util.zip.Inflater;
 
@@ -68,38 +71,40 @@ import java.util.zip.Inflater;
  * Shenyu ai token limiter plugin.
  */
 public class AiTokenLimiterPlugin extends AbstractShenyuPlugin {
-    
+
     private static final Logger LOG = 
LoggerFactory.getLogger(AiTokenLimiterPlugin.class);
-    
+
     private static final String REDIS_KEY_PREFIX = "SHENYU:AI:TOKENLIMIT:";
-    
+
+    private static final Pattern COMPLETION_TOKENS_PATTERN = 
Pattern.compile("\"completion_tokens\"\\s*:\\s*(\\d+)");
+
     @Override
     protected Mono<Void> doExecute(final ServerWebExchange exchange, final 
ShenyuPluginChain chain,
                                    final SelectorData selector, final RuleData 
rule) {
-        
+
         AiTokenLimiterHandle aiTokenLimiterHandle = 
AiTokenLimiterPluginHandler.CACHED_HANDLE.get().obtainHandle(CacheKeyUtils.INST.getKey(rule));
-        
+
         if (Objects.isNull(aiTokenLimiterHandle)) {
             return chain.execute(exchange);
         }
-        
+
         ReactiveRedisTemplate reactiveRedisTemplate = 
AiTokenLimiterPluginHandler.REDIS_CACHED_HANDLE.get().obtainHandle(PluginEnum.AI_TOKEN_LIMITER.getName());
         Assert.notNull(reactiveRedisTemplate, "reactiveRedisTemplate is null");
-        
+
         // generate redis key
         String tokenLimitType = aiTokenLimiterHandle.getAiTokenLimitType();
         String keyName = aiTokenLimiterHandle.getKeyName();
         Long tokenLimit = aiTokenLimiterHandle.getTokenLimit();
         Long timeWindowSeconds = aiTokenLimiterHandle.getTimeWindowSeconds();
-        
+
         String cacheKey = REDIS_KEY_PREFIX + getCacheKey(exchange, 
tokenLimitType, keyName);
-        
+
         final AiStatisticServerHttpResponse loggingServerHttpResponse = new 
AiStatisticServerHttpResponse(exchange, exchange.getResponse(),
                 tokens -> recordTokensUsage(reactiveRedisTemplate,
                         cacheKey,
                         tokens,
                         timeWindowSeconds));
-        
+
         // check if the request is allowed
         return isAllowed(reactiveRedisTemplate, cacheKey, tokenLimit)
                 .flatMap(allowed -> {
@@ -114,12 +119,12 @@ public class AiTokenLimiterPlugin extends 
AbstractShenyuPlugin {
                     ServerWebExchange mutatedExchange = exchange.mutate()
                             .response(loggingServerHttpResponse)
                             .build();
-                    
+
                     return chain.execute(mutatedExchange);
                 });
-        
+
     }
-    
+
     /**
      * Check if the request is allowed based on rate limiting rules.
      *
@@ -129,7 +134,7 @@ public class AiTokenLimiterPlugin extends 
AbstractShenyuPlugin {
      * @return whether the request is allowed
      */
     private Mono<Boolean> isAllowed(final ReactiveRedisTemplate 
reactiveRedisTemplate, final String cacheKey, final Long tokenLimit) {
-        
+
         return reactiveRedisTemplate.opsForValue().get(cacheKey)
                 .defaultIfEmpty(0L)
                 .flatMap(currentTokens -> {
@@ -139,7 +144,7 @@ public class AiTokenLimiterPlugin extends 
AbstractShenyuPlugin {
                     return Mono.just(true);
                 });
     }
-    
+
     /**
      * Get the cache key based on the configured key resolver type.
      *
@@ -165,7 +170,7 @@ public class AiTokenLimiterPlugin extends 
AbstractShenyuPlugin {
             }
             default -> exchange.getAttribute(Constants.CONTEXT_PATH);
         };
-        
+
         return StringUtils.isBlank(key) ? "" : key;
     }
 
@@ -176,38 +181,49 @@ public class AiTokenLimiterPlugin extends 
AbstractShenyuPlugin {
                 .flatMap(currentValue -> 
reactiveRedisTemplate.expire(cacheKey, Duration.ofSeconds(windowSeconds)))
                 .subscribe();
     }
-    
+
     @Override
     public int getOrder() {
         return PluginEnum.AI_TOKEN_LIMITER.getCode();
     }
-    
+
     @Override
     public String named() {
         return PluginEnum.AI_TOKEN_LIMITER.getName();
     }
-    
+
     static class AiStatisticServerHttpResponse extends 
ServerHttpResponseDecorator {
-        
+        private static final ObjectMapper MAPPER = new ObjectMapper();
+
         private final ServerWebExchange exchange;
-        
+
         private final ServerHttpResponse serverHttpResponse;
-        
+
         private final Consumer<Long> tokensRecorder;
-        
+
+        private final AtomicBoolean streamingUsageRecorded = new 
AtomicBoolean(false);
+
         AiStatisticServerHttpResponse(final ServerWebExchange exchange, final 
ServerHttpResponse delegate, final Consumer<Long> tokensRecorder) {
             super(delegate);
             this.exchange = exchange;
             this.serverHttpResponse = delegate;
             this.tokensRecorder = tokensRecorder;
         }
-        
+
         @Override
         @NonNull
         public Mono<Void> writeWith(@NonNull final Publisher<? extends 
DataBuffer> body) {
             return super.writeWith(appendResponse(body));
         }
 
+        @Override
+        @NonNull
+        public Mono<Void> writeAndFlushWith(@NonNull final Publisher<? extends 
Publisher<? extends DataBuffer>> body) {
+            Flux<? extends Publisher<? extends DataBuffer>> intercepted = 
Flux.from(body)
+                    .map(this::appendResponse);
+            return super.writeAndFlushWith(intercepted);
+        }
+
         @NonNull
         private Flux<? extends DataBuffer> appendResponse(final Publisher<? 
extends DataBuffer> body) {
             BodyWriter writer = new BodyWriter();
@@ -228,42 +244,77 @@ public class AiTokenLimiterPlugin extends 
AbstractShenyuPlugin {
                                 byte[] inBytes = new byte[ro.remaining()];
                                 ro.get(inBytes);
 
+                                byte[] processedBytes;
                                 if (isGzip) {
                                     int offset = 0;
-                                    int len = inBytes.length;
-                                    if (!headerSkipped.get()) {
+                                    if (headerSkipped.compareAndSet(false, 
true)) {
                                         offset = skipGzipHeader(inBytes);
-                                        headerSkipped.set(true);
                                     }
-                                    inflater.setInput(inBytes, offset, len - 
offset);
+                                    inflater.setInput(inBytes, offset, 
inBytes.length - offset);
+                                    ByteArrayOutputStream baos = new 
ByteArrayOutputStream();
                                     try {
                                         int cnt;
                                         while ((cnt = 
inflater.inflate(outBuf)) > 0) {
-                                            
writer.write(ByteBuffer.wrap(outBuf, 0, cnt));
+                                            baos.write(outBuf, 0, cnt);
                                         }
                                     } catch (DataFormatException ex) {
-                                        LOG.error("inflater decompression 
failed", ex);
+                                        LOG.error("Inflater decompression 
failed", ex);
                                     }
+                                    processedBytes = baos.toByteArray();
                                 } else {
-                                    writer.write(ro);
+                                    processedBytes = inBytes;
+                                }
+                                String chunk = new String(processedBytes, 
StandardCharsets.UTF_8);
+                                for (String line : chunk.split("\\r?\\n")) {
+                                    if (!line.startsWith("data:")) {
+                                        continue;
+                                    }
+                                    String payload = 
line.substring("data:".length()).trim();
+                                    if (payload.isEmpty() || 
"[DONE]".equals(payload)) {
+                                        continue;
+                                    }
+                                    if (!payload.startsWith("{")) {
+                                        continue;
+                                    }
+                                    try {
+                                        JsonNode node = 
MAPPER.readTree(payload);
+                                        JsonNode usage = 
node.get(Constants.USAGE);
+                                        if (Objects.nonNull(usage) && 
usage.has(Constants.COMPLETION_TOKENS)) {
+                                            long c = 
usage.get(Constants.COMPLETION_TOKENS).asLong();
+                                            tokensRecorder.accept(c);
+                                            streamingUsageRecorded.set(true);
+                                        }
+                                    } catch (Exception e) {
+                                        LOG.error("Failed to parse AI response 
JSON payload", e);
+                                    }
                                 }
+                                writer.write(ByteBuffer.wrap(processedBytes));
                             });
                         } catch (Exception e) {
                             LOG.error("read dataBuffer error", e);
                         }
                     })
                     .doFinally(signal -> {
-                        // release inflater
                         if (Objects.nonNull(inflater)) {
                             inflater.end();
                         }
-                        String responseBody = writer.output();
-                        AiModel aiModel = 
exchange.getAttribute(Constants.AI_MODEL);
-                        long tokens = 
Objects.requireNonNull(aiModel).getCompletionTokens(responseBody);
-                        tokensRecorder.accept(tokens);
+                        if (!streamingUsageRecorded.get()) {
+                            String sse = writer.output();
+                            long usageTokens = extractUsageTokensFromSse(sse);
+                            tokensRecorder.accept(usageTokens);
+                        }
                     });
         }
 
+        private long extractUsageTokensFromSse(final String sse) {
+            Matcher m = COMPLETION_TOKENS_PATTERN.matcher(sse);
+            long last = 0L;
+            while (m.find()) {
+                last = Long.parseLong(m.group(1));
+            }
+            return last;
+        }
+
         private int skipGzipHeader(final byte[] b) {
             int pos = 10;
             int flg = b[3] & 0xFF;
@@ -294,15 +345,15 @@ public class AiTokenLimiterPlugin extends 
AbstractShenyuPlugin {
         }
 
     }
-    
+
     static class BodyWriter {
-        
+
         private final ByteArrayOutputStream stream = new 
ByteArrayOutputStream();
-        
+
         private final WritableByteChannel channel = 
Channels.newChannel(stream);
-        
+
         private final AtomicBoolean isClosed = new AtomicBoolean(false);
-        
+
         void write(final ByteBuffer buffer) {
             if (!isClosed.get()) {
                 try {
@@ -313,11 +364,11 @@ public class AiTokenLimiterPlugin extends 
AbstractShenyuPlugin {
                 }
             }
         }
-        
+
         boolean isEmpty() {
             return stream.size() == 0;
         }
-        
+
         String output() {
             try {
                 isClosed.compareAndSet(false, true);

Reply via email to the mailing list.