This is an automated email from the ASF dual-hosted git repository.
liuhongyu pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/shenyu.git
The following commit(s) were added to refs/heads/master by this push:
new 6898a9282e fix AiTokenLimiterPlugin appendResponse (#6027)
6898a9282e is described below
commit 6898a9282e41dc69a21d0b889b7e0f5c872d6cc5
Author: HY-love-sleep <[email protected]>
AuthorDate: Tue May 20 09:43:38 2025 +0800
fix AiTokenLimiterPlugin appendResponse (#6027)
* fix: one-time decompression after the flow is finished
* chore: code style
* feat: save memory by streaming cross-block decompression
* chore: code style
* chore: del useless imports
---
.../ai/token/limiter/AiTokenLimiterPlugin.java | 156 ++++++++++++---------
1 file changed, 92 insertions(+), 64 deletions(-)
diff --git a/shenyu-plugin/shenyu-plugin-ai/shenyu-plugin-ai-token-limiter/src/main/java/org/apache/shenyu/plugin/ai/token/limiter/AiTokenLimiterPlugin.java b/shenyu-plugin/shenyu-plugin-ai/shenyu-plugin-ai-token-limiter/src/main/java/org/apache/shenyu/plugin/ai/token/limiter/AiTokenLimiterPlugin.java
index 8781885c7d..c7f66d2ae9 100644
--- a/shenyu-plugin/shenyu-plugin-ai/shenyu-plugin-ai-token-limiter/src/main/java/org/apache/shenyu/plugin/ai/token/limiter/AiTokenLimiterPlugin.java
+++ b/shenyu-plugin/shenyu-plugin-ai/shenyu-plugin-ai-token-limiter/src/main/java/org/apache/shenyu/plugin/ai/token/limiter/AiTokenLimiterPlugin.java
@@ -38,6 +38,7 @@ import org.slf4j.LoggerFactory;
import org.springframework.core.io.buffer.DataBuffer;
import org.springframework.data.redis.core.ReactiveRedisTemplate;
import org.springframework.http.HttpCookie;
+import org.springframework.http.HttpHeaders;
import org.springframework.http.HttpStatus;
import org.springframework.http.HttpStatusCode;
import org.springframework.http.server.reactive.ServerHttpRequest;
@@ -49,7 +50,6 @@ import reactor.core.publisher.Flux;
import reactor.core.publisher.Mono;
import reactor.util.annotation.NonNull;
-import java.io.ByteArrayInputStream;
import java.io.ByteArrayOutputStream;
import java.io.IOException;
import java.nio.ByteBuffer;
@@ -61,7 +61,8 @@ import java.util.Objects;
import java.util.Optional;
import java.util.concurrent.atomic.AtomicBoolean;
import java.util.function.Consumer;
-import java.util.zip.GZIPInputStream;
+import java.util.zip.DataFormatException;
+import java.util.zip.Inflater;
/**
* Shenyu ai token limiter plugin.
@@ -152,32 +153,22 @@ public class AiTokenLimiterPlugin extends
AbstractShenyuPlugin {
String key;
// Determine the key based on the configured key resolver type
AiTokenLimiterEnum tokenLimiterEnum =
AiTokenLimiterEnum.getByName(tokenLimitType);
-
- switch (tokenLimiterEnum) {
- case IP:
- key =
Objects.requireNonNull(request.getRemoteAddress()).getHostString();
- break;
- case URI:
- key = request.getURI().getPath();
- break;
- case HEADER:
- key = request.getHeaders().getFirst(keyName);
- break;
- case PARAMETER:
- key = request.getQueryParams().getFirst(keyName);
- break;
- case COOKIE:
+
+ key = switch (tokenLimiterEnum) {
+ case IP ->
Objects.requireNonNull(request.getRemoteAddress()).getHostString();
+ case URI -> request.getURI().getPath();
+ case HEADER -> request.getHeaders().getFirst(keyName);
+ case PARAMETER -> request.getQueryParams().getFirst(keyName);
+ case COOKIE -> {
HttpCookie cookie = request.getCookies().getFirst(keyName);
- key = Objects.nonNull(cookie) ? cookie.getValue() : "";
- break;
- case CONTEXT_PATH:
- default:
- key = exchange.getAttribute(Constants.CONTEXT_PATH);
- }
+ yield Objects.nonNull(cookie) ? cookie.getValue() : "";
+ }
+ default -> exchange.getAttribute(Constants.CONTEXT_PATH);
+ };
return StringUtils.isBlank(key) ? "" : key;
}
-
+
private void recordTokensUsage(final ReactiveRedisTemplate
reactiveRedisTemplate, final String cacheKey, final Long tokens, final Long
windowSeconds) {
// Record token usage with expiration
reactiveRedisTemplate.opsForValue()
@@ -216,55 +207,92 @@ public class AiTokenLimiterPlugin extends
AbstractShenyuPlugin {
public Mono<Void> writeWith(@NonNull final Publisher<? extends
DataBuffer> body) {
return super.writeWith(appendResponse(body));
}
-
+
@NonNull
private Flux<? extends DataBuffer> appendResponse(final Publisher<?
extends DataBuffer> body) {
-
BodyWriter writer = new BodyWriter();
- return Flux.from(body).doOnNext(buffer -> {
- try (DataBuffer.ByteBufferIterator bufferIterator =
buffer.readableByteBuffers()) {
- bufferIterator.forEachRemaining(byteBuffer -> {
- // Handle gzip encoded response
- if
(serverHttpResponse.getHeaders().containsKey(Constants.CONTENT_ENCODING)
- &&
serverHttpResponse.getHeaders().getFirst(Constants.CONTENT_ENCODING).contains(Constants.HTTP_ACCEPT_ENCODING_GZIP))
{
- try {
- ByteBuffer readOnlyBuffer =
byteBuffer.asReadOnlyBuffer();
- byte[] compressed = new
byte[readOnlyBuffer.remaining()];
- readOnlyBuffer.get(compressed);
-
- // Decompress gzipped content
- byte[] decompressed =
decompressGzip(compressed);
- writer.write(ByteBuffer.wrap(decompressed));
-
- } catch (IOException e) {
- LOG.error("Failed to decompress gzipped
response", e);
- writer.write(byteBuffer.asReadOnlyBuffer());
- }
- } else {
- writer.write(byteBuffer.asReadOnlyBuffer());
+ HttpHeaders headers = serverHttpResponse.getHeaders();
+ boolean isGzip = headers.containsKey(Constants.CONTENT_ENCODING)
+ && headers.getFirst(Constants.CONTENT_ENCODING)
+ .contains(Constants.HTTP_ACCEPT_ENCODING_GZIP);
+
+ final Inflater inflater = isGzip ? new Inflater(true) : null;
+ final byte[] outBuf = new byte[4096];
+ final AtomicBoolean headerSkipped = new AtomicBoolean(!isGzip);
+
+ return Flux.<DataBuffer>from(body)
+ .doOnNext(buffer -> {
+ try (DataBuffer.ByteBufferIterator it =
buffer.readableByteBuffers()) {
+ it.forEachRemaining(bb -> {
+ ByteBuffer ro = bb.asReadOnlyBuffer();
+ byte[] inBytes = new byte[ro.remaining()];
+ ro.get(inBytes);
+
+ if (isGzip) {
+ int offset = 0;
+ int len = inBytes.length;
+ if (!headerSkipped.get()) {
+ offset = skipGzipHeader(inBytes);
+ headerSkipped.set(true);
+ }
+ inflater.setInput(inBytes, offset, len -
offset);
+ try {
+ int cnt;
+ while ((cnt =
inflater.inflate(outBuf)) > 0) {
+
writer.write(ByteBuffer.wrap(outBuf, 0, cnt));
+ }
+ } catch (DataFormatException ex) {
+ LOG.error("inflater decompression
failed", ex);
+ }
+ } else {
+ writer.write(ro);
+ }
+ });
+ } catch (Exception e) {
+ LOG.error("read dataBuffer error", e);
+ }
+ })
+ .doFinally(signal -> {
+ // release inflater
+ if (Objects.nonNull(inflater)) {
+ inflater.end();
}
+ String responseBody = writer.output();
+ AiModel aiModel =
exchange.getAttribute(Constants.AI_MODEL);
+ long tokens =
Objects.requireNonNull(aiModel).getCompletionTokens(responseBody);
+ tokensRecorder.accept(tokens);
});
- }
- }).doFinally(signal -> {
- String responseBody = writer.output();
- AiModel aiModel = exchange.getAttribute(Constants.AI_MODEL);
- long tokens =
Objects.requireNonNull(aiModel).getCompletionTokens(responseBody);
- tokensRecorder.accept(tokens);
- });
}
-
- private byte[] decompressGzip(final byte[] compressed) throws
IOException {
- try (GZIPInputStream gzipInputStream = new GZIPInputStream(new
ByteArrayInputStream(compressed));
- ByteArrayOutputStream outputStream = new
ByteArrayOutputStream()) {
- byte[] buffer = new byte[1024];
- int len;
- while ((len = gzipInputStream.read(buffer)) > 0) {
- outputStream.write(buffer, 0, len);
+
+ private int skipGzipHeader(final byte[] b) {
+ int pos = 10;
+ int flg = b[3] & 0xFF;
+
+ if ((flg & 0x04) != 0) {
+ int xlen = (b[pos] & 0xFF) | ((b[pos + 1] & 0xFF) << 8);
+ pos += 2 + xlen;
+ }
+
+ if ((flg & 0x08) != 0) {
+ while (b[pos] != 0) {
+ pos++;
+ }
+ pos++;
+ }
+
+ if ((flg & 0x10) != 0) {
+ while (b[pos] != 0) {
+ pos++;
}
- return outputStream.toByteArray();
+ pos++;
+ }
+
+ if ((flg & 0x02) != 0) {
+ pos += 2;
}
+ return pos;
}
-
+
}
static class BodyWriter {