This is an automated email from the ASF dual-hosted git repository.
liuhongyu pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/shenyu.git
The following commit(s) were added to refs/heads/master by this push:
new 13c70320e7 [feat] optimize aiTokenLimiterPlugin for streaming tokens
(#6055)
13c70320e7 is described below
commit 13c70320e708569a6fc95dde61f70db99c6183b1
Author: HY-love-sleep <[email protected]>
AuthorDate: Thu Jul 17 17:20:36 2025 +0800
[feat] optimize aiTokenLimiterPlugin for streaming tokens (#6055)
* fix: optimize aiTokenLimiterPlugin for streaming tokens
* chore: java format
* chore: code review by copilot
---
.../apache/shenyu/common/constant/Constants.java | 10 ++
.../plugin/ai/common/strategy/openai/OpenAI.java | 8 +-
.../ai/token/limiter/AiTokenLimiterPlugin.java | 139 ++++++++++++++-------
3 files changed, 112 insertions(+), 45 deletions(-)
diff --git
a/shenyu-common/src/main/java/org/apache/shenyu/common/constant/Constants.java
b/shenyu-common/src/main/java/org/apache/shenyu/common/constant/Constants.java
index 25136e5321..f8a25531f1 100644
---
a/shenyu-common/src/main/java/org/apache/shenyu/common/constant/Constants.java
+++
b/shenyu-common/src/main/java/org/apache/shenyu/common/constant/Constants.java
@@ -994,6 +994,16 @@ public interface Constants {
* The constant USAGE.
*/
String USAGE = "usage";
+
+ /**
+ * The include_usage for stream.
+ */
+ String INCLUDE_USAGE = "include_usage";
+
+ /**
+ * The stream_options.
+ */
+ String STREAM_OPTIONS = "stream_options";
/**
* The constant COMPLETION_TOKENS.
diff --git
a/shenyu-plugin/shenyu-plugin-ai/shenyu-plugin-ai-common/src/main/java/org/apache/shenyu/plugin/ai/common/strategy/openai/OpenAI.java
b/shenyu-plugin/shenyu-plugin-ai/shenyu-plugin-ai-common/src/main/java/org/apache/shenyu/plugin/ai/common/strategy/openai/OpenAI.java
index a81416c755..aa1a4f30e6 100644
---
a/shenyu-plugin/shenyu-plugin-ai/shenyu-plugin-ai-common/src/main/java/org/apache/shenyu/plugin/ai/common/strategy/openai/OpenAI.java
+++
b/shenyu-plugin/shenyu-plugin-ai/shenyu-plugin-ai-common/src/main/java/org/apache/shenyu/plugin/ai/common/strategy/openai/OpenAI.java
@@ -35,6 +35,7 @@ import org.springframework.http.codec.HttpMessageReader;
import org.springframework.web.server.ServerWebExchange;
import reactor.core.publisher.Mono;
+import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.Objects;
@@ -66,7 +67,7 @@ public class OpenAI implements AiModel {
return Mono.error(error);
});
}
-
+
@Override
public Long getCompletionTokens(final String responseBody) {
try {
@@ -101,6 +102,11 @@ public class OpenAI implements AiModel {
Map<String, Object> requestBodyMap =
GsonUtils.getInstance().convertToMap(originalBody);
requestBodyMap.put(Constants.MODEL, aiCommonConfig.getModel());
requestBodyMap.put(Constants.STREAM, aiCommonConfig.getStream());
+ if (aiCommonConfig.getStream()) {
+ Map<String, Object> streamOptions = new HashMap<>();
+ streamOptions.put(Constants.INCLUDE_USAGE, true);
+ requestBodyMap.put(Constants.STREAM_OPTIONS, streamOptions);
+ }
return GsonUtils.getInstance().toJson(requestBodyMap);
}
diff --git
a/shenyu-plugin/shenyu-plugin-ai/shenyu-plugin-ai-token-limiter/src/main/java/org/apache/shenyu/plugin/ai/token/limiter/AiTokenLimiterPlugin.java
b/shenyu-plugin/shenyu-plugin-ai/shenyu-plugin-ai-token-limiter/src/main/java/org/apache/shenyu/plugin/ai/token/limiter/AiTokenLimiterPlugin.java
index c7f66d2ae9..0765ed9695 100644
---
a/shenyu-plugin/shenyu-plugin-ai/shenyu-plugin-ai-token-limiter/src/main/java/org/apache/shenyu/plugin/ai/token/limiter/AiTokenLimiterPlugin.java
+++
b/shenyu-plugin/shenyu-plugin-ai/shenyu-plugin-ai-token-limiter/src/main/java/org/apache/shenyu/plugin/ai/token/limiter/AiTokenLimiterPlugin.java
@@ -17,6 +17,8 @@
package org.apache.shenyu.plugin.ai.token.limiter;
+import com.fasterxml.jackson.databind.JsonNode;
+import com.fasterxml.jackson.databind.ObjectMapper;
import org.apache.commons.lang3.StringUtils;
import org.apache.shenyu.common.constant.Constants;
import org.apache.shenyu.common.dto.RuleData;
@@ -24,7 +26,6 @@ import org.apache.shenyu.common.dto.SelectorData;
import org.apache.shenyu.common.dto.convert.rule.AiTokenLimiterHandle;
import org.apache.shenyu.common.enums.AiTokenLimiterEnum;
import org.apache.shenyu.common.enums.PluginEnum;
-import org.apache.shenyu.plugin.ai.common.strategy.AiModel;
import
org.apache.shenyu.plugin.ai.token.limiter.handler.AiTokenLimiterPluginHandler;
import org.apache.shenyu.plugin.api.ShenyuPluginChain;
import org.apache.shenyu.plugin.api.result.ShenyuResultEnum;
@@ -61,6 +62,8 @@ import java.util.Objects;
import java.util.Optional;
import java.util.concurrent.atomic.AtomicBoolean;
import java.util.function.Consumer;
+import java.util.regex.Matcher;
+import java.util.regex.Pattern;
import java.util.zip.DataFormatException;
import java.util.zip.Inflater;
@@ -68,38 +71,40 @@ import java.util.zip.Inflater;
* Shenyu ai token limiter plugin.
*/
public class AiTokenLimiterPlugin extends AbstractShenyuPlugin {
-
+
private static final Logger LOG =
LoggerFactory.getLogger(AiTokenLimiterPlugin.class);
-
+
private static final String REDIS_KEY_PREFIX = "SHENYU:AI:TOKENLIMIT:";
-
+
+ private static final Pattern COMPLETION_TOKENS_PATTERN =
Pattern.compile("\"completion_tokens\"\\s*:\\s*(\\d+)");
+
@Override
protected Mono<Void> doExecute(final ServerWebExchange exchange, final
ShenyuPluginChain chain,
final SelectorData selector, final RuleData
rule) {
-
+
AiTokenLimiterHandle aiTokenLimiterHandle =
AiTokenLimiterPluginHandler.CACHED_HANDLE.get().obtainHandle(CacheKeyUtils.INST.getKey(rule));
-
+
if (Objects.isNull(aiTokenLimiterHandle)) {
return chain.execute(exchange);
}
-
+
ReactiveRedisTemplate reactiveRedisTemplate =
AiTokenLimiterPluginHandler.REDIS_CACHED_HANDLE.get().obtainHandle(PluginEnum.AI_TOKEN_LIMITER.getName());
Assert.notNull(reactiveRedisTemplate, "reactiveRedisTemplate is null");
-
+
// generate redis key
String tokenLimitType = aiTokenLimiterHandle.getAiTokenLimitType();
String keyName = aiTokenLimiterHandle.getKeyName();
Long tokenLimit = aiTokenLimiterHandle.getTokenLimit();
Long timeWindowSeconds = aiTokenLimiterHandle.getTimeWindowSeconds();
-
+
String cacheKey = REDIS_KEY_PREFIX + getCacheKey(exchange,
tokenLimitType, keyName);
-
+
final AiStatisticServerHttpResponse loggingServerHttpResponse = new
AiStatisticServerHttpResponse(exchange, exchange.getResponse(),
tokens -> recordTokensUsage(reactiveRedisTemplate,
cacheKey,
tokens,
timeWindowSeconds));
-
+
// check if the request is allowed
return isAllowed(reactiveRedisTemplate, cacheKey, tokenLimit)
.flatMap(allowed -> {
@@ -114,12 +119,12 @@ public class AiTokenLimiterPlugin extends
AbstractShenyuPlugin {
ServerWebExchange mutatedExchange = exchange.mutate()
.response(loggingServerHttpResponse)
.build();
-
+
return chain.execute(mutatedExchange);
});
-
+
}
-
+
/**
* Check if the request is allowed based on rate limiting rules.
*
@@ -129,7 +134,7 @@ public class AiTokenLimiterPlugin extends
AbstractShenyuPlugin {
* @return whether the request is allowed
*/
private Mono<Boolean> isAllowed(final ReactiveRedisTemplate
reactiveRedisTemplate, final String cacheKey, final Long tokenLimit) {
-
+
return reactiveRedisTemplate.opsForValue().get(cacheKey)
.defaultIfEmpty(0L)
.flatMap(currentTokens -> {
@@ -139,7 +144,7 @@ public class AiTokenLimiterPlugin extends
AbstractShenyuPlugin {
return Mono.just(true);
});
}
-
+
/**
* Get the cache key based on the configured key resolver type.
*
@@ -165,7 +170,7 @@ public class AiTokenLimiterPlugin extends
AbstractShenyuPlugin {
}
default -> exchange.getAttribute(Constants.CONTEXT_PATH);
};
-
+
return StringUtils.isBlank(key) ? "" : key;
}
@@ -176,38 +181,49 @@ public class AiTokenLimiterPlugin extends
AbstractShenyuPlugin {
.flatMap(currentValue ->
reactiveRedisTemplate.expire(cacheKey, Duration.ofSeconds(windowSeconds)))
.subscribe();
}
-
+
@Override
public int getOrder() {
return PluginEnum.AI_TOKEN_LIMITER.getCode();
}
-
+
@Override
public String named() {
return PluginEnum.AI_TOKEN_LIMITER.getName();
}
-
+
static class AiStatisticServerHttpResponse extends
ServerHttpResponseDecorator {
-
+ private static final ObjectMapper MAPPER = new ObjectMapper();
+
private final ServerWebExchange exchange;
-
+
private final ServerHttpResponse serverHttpResponse;
-
+
private final Consumer<Long> tokensRecorder;
-
+
+ private final AtomicBoolean streamingUsageRecorded = new
AtomicBoolean(false);
+
AiStatisticServerHttpResponse(final ServerWebExchange exchange, final
ServerHttpResponse delegate, final Consumer<Long> tokensRecorder) {
super(delegate);
this.exchange = exchange;
this.serverHttpResponse = delegate;
this.tokensRecorder = tokensRecorder;
}
-
+
@Override
@NonNull
public Mono<Void> writeWith(@NonNull final Publisher<? extends
DataBuffer> body) {
return super.writeWith(appendResponse(body));
}
+ @Override
+ @NonNull
+ public Mono<Void> writeAndFlushWith(@NonNull final Publisher<? extends
Publisher<? extends DataBuffer>> body) {
+ Flux<? extends Publisher<? extends DataBuffer>> intercepted =
Flux.from(body)
+ .map(this::appendResponse);
+ return super.writeAndFlushWith(intercepted);
+ }
+
@NonNull
private Flux<? extends DataBuffer> appendResponse(final Publisher<?
extends DataBuffer> body) {
BodyWriter writer = new BodyWriter();
@@ -228,42 +244,77 @@ public class AiTokenLimiterPlugin extends
AbstractShenyuPlugin {
byte[] inBytes = new byte[ro.remaining()];
ro.get(inBytes);
+ byte[] processedBytes;
if (isGzip) {
int offset = 0;
- int len = inBytes.length;
- if (!headerSkipped.get()) {
+ if (headerSkipped.compareAndSet(false,
true)) {
offset = skipGzipHeader(inBytes);
- headerSkipped.set(true);
}
- inflater.setInput(inBytes, offset, len -
offset);
+ inflater.setInput(inBytes, offset,
inBytes.length - offset);
+ ByteArrayOutputStream baos = new
ByteArrayOutputStream();
try {
int cnt;
while ((cnt =
inflater.inflate(outBuf)) > 0) {
-
writer.write(ByteBuffer.wrap(outBuf, 0, cnt));
+ baos.write(outBuf, 0, cnt);
}
} catch (DataFormatException ex) {
- LOG.error("inflater decompression
failed", ex);
+ LOG.error("Inflater decompression
failed", ex);
}
+ processedBytes = baos.toByteArray();
} else {
- writer.write(ro);
+ processedBytes = inBytes;
+ }
+ String chunk = new String(processedBytes,
StandardCharsets.UTF_8);
+ for (String line : chunk.split("\\r?\\n")) {
+ if (!line.startsWith("data:")) {
+ continue;
+ }
+ String payload =
line.substring("data:".length()).trim();
+ if (payload.isEmpty() ||
"[DONE]".equals(payload)) {
+ continue;
+ }
+ if (!payload.startsWith("{")) {
+ continue;
+ }
+ try {
+ JsonNode node =
MAPPER.readTree(payload);
+ JsonNode usage =
node.get(Constants.USAGE);
+ if (Objects.nonNull(usage) &&
usage.has(Constants.COMPLETION_TOKENS)) {
+ long c =
usage.get(Constants.COMPLETION_TOKENS).asLong();
+ tokensRecorder.accept(c);
+ streamingUsageRecorded.set(true);
+ }
+ } catch (Exception e) {
+ LOG.error("Failed to parse AI response
JSON payload", e);
+ }
}
+ writer.write(ByteBuffer.wrap(processedBytes));
});
} catch (Exception e) {
LOG.error("read dataBuffer error", e);
}
})
.doFinally(signal -> {
- // release inflater
if (Objects.nonNull(inflater)) {
inflater.end();
}
- String responseBody = writer.output();
- AiModel aiModel =
exchange.getAttribute(Constants.AI_MODEL);
- long tokens =
Objects.requireNonNull(aiModel).getCompletionTokens(responseBody);
- tokensRecorder.accept(tokens);
+ if (!streamingUsageRecorded.get()) {
+ String sse = writer.output();
+ long usageTokens = extractUsageTokensFromSse(sse);
+ tokensRecorder.accept(usageTokens);
+ }
});
}
+ private long extractUsageTokensFromSse(final String sse) {
+ Matcher m = COMPLETION_TOKENS_PATTERN.matcher(sse);
+ long last = 0L;
+ while (m.find()) {
+ last = Long.parseLong(m.group(1));
+ }
+ return last;
+ }
+
private int skipGzipHeader(final byte[] b) {
int pos = 10;
int flg = b[3] & 0xFF;
@@ -294,15 +345,15 @@ public class AiTokenLimiterPlugin extends
AbstractShenyuPlugin {
}
}
-
+
static class BodyWriter {
-
+
private final ByteArrayOutputStream stream = new
ByteArrayOutputStream();
-
+
private final WritableByteChannel channel =
Channels.newChannel(stream);
-
+
private final AtomicBoolean isClosed = new AtomicBoolean(false);
-
+
void write(final ByteBuffer buffer) {
if (!isClosed.get()) {
try {
@@ -313,11 +364,11 @@ public class AiTokenLimiterPlugin extends
AbstractShenyuPlugin {
}
}
}
-
+
boolean isEmpty() {
return stream.size() == 0;
}
-
+
String output() {
try {
isClosed.compareAndSet(false, true);