This is an automated email from the ASF dual-hosted git repository.

shreemaanabhishek pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/apisix.git


The following commit(s) were added to refs/heads/master by this push:
     new 52d8fea3d feat: ai-rate-limiting plugin (#12037)
52d8fea3d is described below

commit 52d8fea3dbc93c4b59ac6682902955c6146af4ad
Author: Shreemaan Abhishek <[email protected]>
AuthorDate: Thu Mar 13 13:56:06 2025 +0545

    feat: ai-rate-limiting plugin (#12037)
---
 apisix/cli/config.lua                      |   3 +
 apisix/cli/ngx_tpl.lua                     |  14 +-
 apisix/plugins/ai-rate-limiting.lua        | 209 +++++++++++
 apisix/plugins/limit-count/init.lua        |  30 +-
 conf/config.yaml.example                   |   1 +
 docs/en/latest/config.json                 |   3 +-
 docs/en/latest/plugins/ai-rate-limiting.md | 117 +++++++
 t/APISIX.pm                                |   2 +
 t/admin/plugins.t                          |   1 +
 t/plugin/ai-rate-limiting.t                | 539 +++++++++++++++++++++++++++++
 10 files changed, 907 insertions(+), 12 deletions(-)

diff --git a/apisix/cli/config.lua b/apisix/cli/config.lua
index 56af978c2..8cca713e0 100644
--- a/apisix/cli/config.lua
+++ b/apisix/cli/config.lua
@@ -160,6 +160,8 @@ local _M = {
         ["plugin-limit-req-redis-cluster-slot-lock"] = "1m",
         ["plugin-limit-count-redis-cluster-slot-lock"] = "1m",
         ["plugin-limit-conn-redis-cluster-slot-lock"] = "1m",
+        ["plugin-ai-rate-limiting"] = "10m",
+        ["plugin-ai-rate-limiting-reset-header"] = "10m",
         tracing_buffer = "10m",
         ["plugin-api-breaker"] = "10m",
         ["etcd-cluster-health-check"] = "10m",
@@ -219,6 +221,7 @@ local _M = {
     "ai-prompt-decorator",
     "ai-prompt-guard",
     "ai-rag",
+    "ai-rate-limiting",
     "ai-proxy-multi",
     "ai-proxy",
     "ai-aws-content-moderation",
diff --git a/apisix/cli/ngx_tpl.lua b/apisix/cli/ngx_tpl.lua
index 4b7ff4102..dec8f7172 100644
--- a/apisix/cli/ngx_tpl.lua
+++ b/apisix/cli/ngx_tpl.lua
@@ -287,6 +287,19 @@ http {
     lua_shared_dict tars {* http.lua_shared_dict["tars"] *};
     {% end %}
 
+
+    {% if http.lua_shared_dict["plugin-ai-rate-limiting"] then %}
+    lua_shared_dict plugin-ai-rate-limiting {* 
http.lua_shared_dict["plugin-ai-rate-limiting"] *};
+    {% else %}
+    lua_shared_dict plugin-ai-rate-limiting 10m;
+    {% end %}
+
+    {% if http.lua_shared_dict["plugin-ai-rate-limiting-reset-header"] then %}
+    lua_shared_dict plugin-ai-rate-limiting-reset-header {* 
http.lua_shared_dict["plugin-ai-rate-limiting-reset-header"] *};
+    {% else %}
+    lua_shared_dict plugin-ai-rate-limiting-reset-header 10m;
+    {% end %}
+
     {% if enabled_plugins["limit-conn"] then %}
     lua_shared_dict plugin-limit-conn {* 
http.lua_shared_dict["plugin-limit-conn"] *};
     lua_shared_dict plugin-limit-conn-redis-cluster-slot-lock {* 
http.lua_shared_dict["plugin-limit-conn-redis-cluster-slot-lock"] *};
@@ -418,7 +431,6 @@ http {
     {% if ssl.ssl_trusted_certificate ~= nil then %}
     lua_ssl_trusted_certificate {* ssl.ssl_trusted_certificate *};
     {% end %}
-
     # http configuration snippet starts
     {% if http_configuration_snippet then %}
     {* http_configuration_snippet *}
diff --git a/apisix/plugins/ai-rate-limiting.lua 
b/apisix/plugins/ai-rate-limiting.lua
new file mode 100644
index 000000000..374c03997
--- /dev/null
+++ b/apisix/plugins/ai-rate-limiting.lua
@@ -0,0 +1,209 @@
+--
+-- Licensed to the Apache Software Foundation (ASF) under one or more
+-- contributor license agreements.  See the NOTICE file distributed with
+-- this work for additional information regarding copyright ownership.
+-- The ASF licenses this file to You under the Apache License, Version 2.0
+-- (the "License"); you may not use this file except in compliance with
+-- the License.  You may obtain a copy of the License at
+--
+--     http://www.apache.org/licenses/LICENSE-2.0
+--
+-- Unless required by applicable law or agreed to in writing, software
+-- distributed under the License is distributed on an "AS IS" BASIS,
+-- WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+-- See the License for the specific language governing permissions and
+-- limitations under the License.
+--
+local require = require
+local setmetatable = setmetatable
+local ipairs = ipairs
+local type = type
+local core = require("apisix.core")
+local limit_count = require("apisix.plugins.limit-count.init")
+
+local plugin_name = "ai-rate-limiting"
+
+local instance_limit_schema = {
+    type = "object",
+    properties = {
+        name = {type = "string"},
+        limit = {type = "integer", minimum = 1},
+        time_window = {type = "integer", minimum = 1}
+    },
+    required = {"name", "limit", "time_window"}
+}
+
+local schema = {
+    type = "object",
+    properties = {
+        limit = {type = "integer", exclusiveMinimum = 0},
+        time_window = {type = "integer",  exclusiveMinimum = 0},
+        show_limit_quota_header = {type = "boolean", default = true},
+        limit_strategy = {
+            type = "string",
+            enum = {"total_tokens", "prompt_tokens", "completion_tokens"},
+            default = "total_tokens",
+            description = "The strategy to limit the tokens"
+        },
+        instances = {
+            type = "array",
+            items = instance_limit_schema
+        },
+        rejected_code = {
+            type = "integer", minimum = 200, maximum = 599, default = 503
+        },
+        rejected_msg = {
+            type = "string", minLength = 1
+        },
+    },
+    required = {"limit", "time_window"},
+}
+
+local _M = {
+    version = 0.1,
+    priority = 1030,
+    name = plugin_name,
+    schema = schema
+}
+
+local limit_conf_cache = core.lrucache.new({
+    ttl = 300, count = 512
+})
+
+
+function _M.check_schema(conf)
+    return core.schema.check(schema, conf)
+end
+
+
+local function transform_limit_conf(plugin_conf, instance_conf, instance_name)
+    local key = plugin_name .. "#global"
+    local limit = plugin_conf.limit
+    local time_window = plugin_conf.time_window
+    local name = instance_name or ""
+    if instance_conf then
+        name = instance_conf.name
+        key = instance_conf.name
+        limit = instance_conf.limit
+        time_window = instance_conf.time_window
+    end
+    return {
+        _vid = key,
+
+        key = key,
+        count = limit,
+        time_window = time_window,
+        rejected_code = plugin_conf.rejected_code,
+        rejected_msg = plugin_conf.rejected_msg,
+        show_limit_quota_header = plugin_conf.show_limit_quota_header,
+        -- limit-count need these fields
+        policy = "local",
+        key_type = "constant",
+        allow_degradation = false,
+        sync_interval = -1,
+
+        limit_header = "X-AI-RateLimit-Limit-" .. name,
+        remaining_header = "X-AI-RateLimit-Remaining-" .. name,
+        reset_header = "X-AI-RateLimit-Reset-" .. name,
+    }
+end
+
+
+local function fetch_limit_conf_kvs(conf)
+    local mt = {
+        __index = function(t, k)
+            local limit_conf = transform_limit_conf(conf, nil, k)
+            t[k] = limit_conf
+            return limit_conf
+        end
+    }
+    local limit_conf_kvs = setmetatable({}, mt)
+    local conf_instances = conf.instances or {}
+    for _, limit_conf in ipairs(conf_instances) do
+        limit_conf_kvs[limit_conf.name] = transform_limit_conf(conf, 
limit_conf)
+    end
+    return limit_conf_kvs
+end
+
+
+function _M.access(conf, ctx)
+    local ai_instance_name = ctx.picked_ai_instance_name
+    if not ai_instance_name then
+        return
+    end
+
+    local limit_conf_kvs = limit_conf_cache(conf, nil, fetch_limit_conf_kvs, 
conf)
+    local limit_conf = limit_conf_kvs[ai_instance_name]
+    local code, msg = limit_count.rate_limit(limit_conf, ctx, plugin_name, 1, 
true)
+    ctx.ai_rate_limiting = code and true or false
+    return code, msg
+end
+
+
+function _M.check_instance_status(conf, ctx, instance_name)
+    if conf == nil then
+        local plugins = ctx.plugins
+        for _, plugin in ipairs(plugins) do
+            if plugin.name == plugin_name then
+                conf = plugin
+            end
+        end
+    end
+    if not conf then
+        return true
+    end
+
+    instance_name = instance_name or ctx.picked_ai_instance_name
+    if not instance_name then
+        return nil, "missing instance_name"
+    end
+
+    if type(instance_name) ~= "string" then
+        return nil, "invalid instance_name"
+    end
+
+    local limit_conf_kvs = limit_conf_cache(conf, nil, fetch_limit_conf_kvs, 
conf)
+    local limit_conf = limit_conf_kvs[instance_name]
+    local code, _ = limit_count.rate_limit(limit_conf, ctx, plugin_name, 1, 
true)
+    if code then
+        core.log.info("rate limit for instance: ", instance_name, " code: ", 
code)
+        return false
+    end
+    return true
+end
+
+
+local function get_token_usage(conf, ctx)
+    local usage = ctx.ai_token_usage
+    if not usage then
+        return
+    end
+    return usage[conf.limit_strategy]
+end
+
+
+function _M.log(conf, ctx)
+    local instance_name = ctx.picked_ai_instance_name
+    if not instance_name then
+        return
+    end
+
+    if ctx.ai_rate_limiting then
+        return
+    end
+
+    local used_tokens = get_token_usage(conf, ctx)
+    if not used_tokens then
+        core.log.error("failed to get token usage for llm service")
+        return
+    end
+
+    core.log.info("instance name: ", instance_name, " used tokens: ", 
used_tokens)
+
+    local limit_conf_kvs = limit_conf_cache(conf, nil, fetch_limit_conf_kvs, 
conf)
+    local limit_conf = limit_conf_kvs[instance_name]
+    limit_count.rate_limit(limit_conf, ctx, plugin_name, used_tokens)
+end
+
+
+return _M
diff --git a/apisix/plugins/limit-count/init.lua 
b/apisix/plugins/limit-count/init.lua
index 08a4c9763..1f37965c4 100644
--- a/apisix/plugins/limit-count/init.lua
+++ b/apisix/plugins/limit-count/init.lua
@@ -21,6 +21,7 @@ local ipairs = ipairs
 local pairs = pairs
 local redis_schema = require("apisix.utils.redis-schema")
 local policy_to_additional_properties = redis_schema.schema
+local get_phase = ngx.get_phase
 
 local limit_redis_cluster_new
 local limit_redis_new
@@ -233,8 +234,9 @@ local function gen_limit_obj(conf, ctx, plugin_name)
     return core.lrucache.plugin_ctx(lrucache, ctx, extra_key, 
create_limit_obj, conf, plugin_name)
 end
 
-function _M.rate_limit(conf, ctx, name, cost)
+function _M.rate_limit(conf, ctx, name, cost, dry_run)
     core.log.info("ver: ", ctx.conf_version)
+    core.log.info("conf: ", core.json.delay_encode(conf, true))
 
     local lim, err = gen_limit_obj(conf, ctx, name)
 
@@ -275,7 +277,7 @@ function _M.rate_limit(conf, ctx, name, cost)
 
     local delay, remaining, reset
     if not conf.policy or conf.policy == "local" then
-        delay, remaining, reset = lim:incoming(key, true, conf, cost)
+        delay, remaining, reset = lim:incoming(key, not dry_run, conf, cost)
     else
         delay, remaining, reset = lim:incoming(key, cost)
     end
@@ -288,14 +290,22 @@ function _M.rate_limit(conf, ctx, name, cost)
     end
     core.log.info("limit-count plugin-metadata: ", 
core.json.delay_encode(metadata))
 
+    local set_limit_headers = {
+        limit_header = conf.limit_header or metadata.limit_header,
+        remaining_header = conf.remaining_header or metadata.remaining_header,
+        reset_header = conf.reset_header or metadata.reset_header,
+    }
+    local phase = get_phase()
+    local set_header = phase ~= "log"
+
     if not delay then
         local err = remaining
         if err == "rejected" then
             -- show count limit header when rejected
-            if conf.show_limit_quota_header then
-                core.response.set_header(metadata.limit_header, conf.count,
-                    metadata.remaining_header, 0,
-                    metadata.reset_header, reset)
+            if conf.show_limit_quota_header and set_header then
+                core.response.set_header(set_limit_headers.limit_header, 
conf.count,
+                set_limit_headers.remaining_header, 0,
+                set_limit_headers.reset_header, reset)
             end
 
             if conf.rejected_msg then
@@ -311,10 +321,10 @@ function _M.rate_limit(conf, ctx, name, cost)
         return 500, {error_msg = "failed to limit count"}
     end
 
-    if conf.show_limit_quota_header then
-        core.response.set_header(metadata.limit_header, conf.count,
-            metadata.remaining_header, remaining,
-            metadata.reset_header, reset)
+    if conf.show_limit_quota_header and set_header then
+        core.response.set_header(set_limit_headers.limit_header, conf.count,
+            set_limit_headers.remaining_header, remaining,
+            set_limit_headers.reset_header, reset)
     end
 end
 
diff --git a/conf/config.yaml.example b/conf/config.yaml.example
index c0da9c0bf..22035f8da 100644
--- a/conf/config.yaml.example
+++ b/conf/config.yaml.example
@@ -480,6 +480,7 @@ plugins:                           # plugin list (sorted by 
priority)
   - ai-prompt-decorator            # priority: 1070
   - ai-prompt-guard                # priority: 1072
   - ai-rag                         # priority: 1060
+  - ai-rate-limiting               # priority: 1030
   - ai-aws-content-moderation      # priority: 1040 TODO: compare priority 
with other ai plugins
   - proxy-mirror                   # priority: 1010
   - proxy-rewrite                  # priority: 1008
diff --git a/docs/en/latest/config.json b/docs/en/latest/config.json
index b540df5f4..7074fbfc2 100644
--- a/docs/en/latest/config.json
+++ b/docs/en/latest/config.json
@@ -158,7 +158,8 @@
             "plugins/request-id",
             "plugins/proxy-control",
             "plugins/client-control",
-            "plugins/workflow"
+            "plugins/workflow",
+            "plugins/ai-rate-limiting"
           ]
         },
         {
diff --git a/docs/en/latest/plugins/ai-rate-limiting.md 
b/docs/en/latest/plugins/ai-rate-limiting.md
new file mode 100644
index 000000000..839818153
--- /dev/null
+++ b/docs/en/latest/plugins/ai-rate-limiting.md
@@ -0,0 +1,117 @@
+---
+title: AI Rate Limiting
+keywords:
+  - Apache APISIX
+  - API Gateway
+  - Plugin
+  - ai-rate-limiting
+description: The ai-rate-limiting plugin enforces token-based rate limiting 
for LLM service requests, preventing overuse, optimizing API consumption, and 
ensuring efficient resource allocation.
+---
+
+<!--
+#
+# Licensed to the Apache Software Foundation (ASF) under one or more
+# contributor license agreements.  See the NOTICE file distributed with
+# this work for additional information regarding copyright ownership.
+# The ASF licenses this file to You under the Apache License, Version 2.0
+# (the "License"); you may not use this file except in compliance with
+# the License.  You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+-->
+
+## Description
+
+The `ai-rate-limiting` plugin enforces token-based rate limiting for requests 
sent to LLM services. It helps manage API usage by controlling the number of 
tokens consumed within a specified time frame, ensuring fair resource 
allocation and preventing excessive load on the service. It is often used with 
`ai-proxy` or `ai-proxy-multi` plugin.
+
+## Plugin Attributes
+
+| Name                      | Type          | Required | Description           
                                                                                
                                                                                
                                                                                
                        |
+| ------------------------- | ------------- | -------- | 
---------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------
 |
+| `limit`                   | integer       | false    | The maximum number of 
tokens allowed to consume within a given time interval. At least one of `limit` 
and `instances.limit` should be configured.                                     
                                                                                
                        |
+| `time_window`             | integer       | false    | The time interval 
corresponding to the rate limiting `limit` in seconds. At least one of 
`time_window` and `instances.time_window` should be configured.                 
                                                                                
                                     |
+| `show_limit_quota_header` | boolean       | false    | If true, include 
`X-AI-RateLimit-Limit-*` to show the total quota, `X-AI-RateLimit-Remaining-*` 
to show the remaining quota in the response header, and 
`X-AI-RateLimit-Reset-*` to show the number of seconds left for the counter to 
reset, where `*` is the instance name. Default: `true` |
+| `limit_strategy`          | string        | false    | Type of token to 
apply rate limiting. `total_tokens`, `prompt_tokens`, and `completion_tokens` 
values are returned in each model response, where `total_tokens` is the sum of 
`prompt_tokens` and `completion_tokens`. Default: `total_tokens`                
                                |
+| `instances`               | array[object] | false    | LLM instance rate 
limiting configurations.                                                        
                                                                                
                                                                                
                            |
+| `instances.name`          | string        | true     | Name of the LLM 
service instance.                                                               
                                                                                
                                                                                
                              |
+| `instances.limit`         | integer       | true     | The maximum number of 
tokens allowed to consume within a given time interval.                         
                                                                                
                                                                                
                        |
+| `instances.time_window`   | integer       | true     | The time interval 
corresponding to the rate limiting `limit` in seconds.                          
                                                                                
                                                                                
                            |
+| `rejected_code`           | integer       | false    | The HTTP status code 
returned when a request exceeding the quota is rejected. Default: `503`         
                                                                                
                                                                                
                         |
+| `rejected_msg`            | string        | false    | The response body 
returned when a request exceeding the quota is rejected.                        
                                                                                
                                                                                
                            |
+
+## Example
+
+Create a route as such and update with your LLM providers, models, API keys, 
and endpoints:
+
+```shell
+curl "http://127.0.0.1:9180/apisix/admin/routes" -X PUT \
+  -H "X-API-KEY: ${ADMIN_API_KEY}" \
+  -d '{
+    "id": "ai-rate-limiting-route",
+    "uri": "/anything",
+    "methods": ["POST"],
+    "plugins": {
+      "ai-proxy": {
+        "provider": "openai",
+        "auth": {
+          "header": {
+            "Authorization": "Bearer '"$API_KEY"'"
+          }
+        },
+        "options": {
+          "model": "gpt-35-turbo-instruct",
+          "max_tokens": 512,
+          "temperature": 1.0
+        }
+      },
+      "ai-rate-limiting": {
+        "limit": 300,
+        "time_window": 30,
+        "limit_strategy": "prompt_tokens"
+      }
+    }
+  }'
+```
+
+Send a POST request to the route with a system prompt and a sample user 
question in the request body:
+
+```shell
+curl "http://127.0.0.1:9080/anything" -X POST \
+  -H "Content-Type: application/json" \
+  -d '{
+    "messages": [
+      { "role": "system", "content": "You are a mathematician" },
+      { "role": "user", "content": "What is 1+1?" }
+    ]
+  }'
+```
+
+You should receive a response similar to the following:
+
+```json
+{
+  ...
+  "model": "gpt-35-turbo-instruct",
+  "choices": [
+    {
+      "index": 0,
+      "message": {
+        "role": "assistant",
+        "content": "1 + 1 equals 2. This is a fundamental arithmetic operation 
where adding one unit to another results in a total of two units."
+      },
+      "logprobs": null,
+      "finish_reason": "stop"
+    }
+  ],
+  ...
+}
+```
+
+If rate limiting quota of 300 tokens has been consumed in a 30-second window, 
the additional requests will all be rejected.
diff --git a/t/APISIX.pm b/t/APISIX.pm
index 2e1724a12..f4b9b8055 100644
--- a/t/APISIX.pm
+++ b/t/APISIX.pm
@@ -558,6 +558,8 @@ _EOC_
     lua_shared_dict plugin-limit-count 10m;
     lua_shared_dict plugin-limit-count-reset-header 10m;
     lua_shared_dict plugin-limit-conn 10m;
+    lua_shared_dict plugin-ai-rate-limiting 10m;
+    lua_shared_dict plugin-ai-rate-limiting-reset-header 10m;
     lua_shared_dict internal-status 10m;
     lua_shared_dict upstream-healthcheck 32m;
     lua_shared_dict worker-events 10m;
diff --git a/t/admin/plugins.t b/t/admin/plugins.t
index c43d5ffeb..2a759bbe3 100644
--- a/t/admin/plugins.t
+++ b/t/admin/plugins.t
@@ -101,6 +101,7 @@ ai-rag
 ai-aws-content-moderation
 ai-proxy-multi
 ai-proxy
+ai-rate-limiting
 proxy-mirror
 proxy-rewrite
 workflow
diff --git a/t/plugin/ai-rate-limiting.t b/t/plugin/ai-rate-limiting.t
new file mode 100644
index 000000000..a6396acf0
--- /dev/null
+++ b/t/plugin/ai-rate-limiting.t
@@ -0,0 +1,539 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one or more
+# contributor license agreements.  See the NOTICE file distributed with
+# this work for additional information regarding copyright ownership.
+# The ASF licenses this file to You under the Apache License, Version 2.0
+# (the "License"); you may not use this file except in compliance with
+# the License.  You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+use t::APISIX 'no_plan';
+
+log_level("info");
+repeat_each(1);
+no_long_string();
+no_root_location();
+
+
+my $resp_file = 't/assets/ai-proxy-response.json';
+open(my $fh, '<', $resp_file) or die "Could not open file '$resp_file' $!";
+my $resp = do { local $/; <$fh> };
+close($fh);
+
+print "Hello, World!\n";
+print $resp;
+
+
+add_block_preprocessor(sub {
+    my ($block) = @_;
+
+    if (!defined $block->request) {
+        $block->set_value("request", "GET /t");
+    }
+
+    my $http_config = $block->http_config // <<_EOC_;
+        server {
+            server_name openai;
+            listen 16724;
+
+            default_type 'application/json';
+
+            location /anything {
+                content_by_lua_block {
+                    local json = require("cjson.safe")
+
+                    if ngx.req.get_method() ~= "POST" then
+                        ngx.status = 400
+                        ngx.say("Unsupported request method: ", 
ngx.req.get_method())
+                    end
+                    ngx.req.read_body()
+                    local body = ngx.req.get_body_data()
+
+                    if body ~= "SELECT * FROM STUDENTS" then
+                        ngx.status = 503
+                        ngx.say("passthrough doesn't work")
+                        return
+                    end
+                    ngx.say('{"foo", "bar"}')
+                }
+            }
+
+            location /v1/chat/completions {
+                content_by_lua_block {
+                    local json = require("cjson.safe")
+
+                    if ngx.req.get_method() ~= "POST" then
+                        ngx.status = 400
+                        ngx.say("Unsupported request method: ", 
ngx.req.get_method())
+                    end
+                    ngx.req.read_body()
+                    local body, err = ngx.req.get_body_data()
+                    body, err = json.decode(body)
+
+                    local test_type = ngx.req.get_headers()["test-type"]
+                    if test_type == "options" then
+                        if body.foo == "bar" then
+                            ngx.status = 200
+                            ngx.say("options works")
+                        else
+                            ngx.status = 500
+                            ngx.say("model options feature doesn't work")
+                        end
+                        return
+                    end
+
+                    local header_auth = ngx.req.get_headers()["authorization"]
+                    local query_auth = ngx.req.get_uri_args()["apikey"]
+
+                    if header_auth ~= "Bearer token" and query_auth ~= 
"apikey" then
+                        ngx.status = 401
+                        ngx.say("Unauthorized")
+                        return
+                    end
+
+                    if header_auth == "Bearer token" or query_auth == "apikey" 
then
+                        ngx.req.read_body()
+                        local body, err = ngx.req.get_body_data()
+                        body, err = json.decode(body)
+
+                        if not body.messages or #body.messages < 1 then
+                            ngx.status = 400
+                            ngx.say([[{ "error": "bad request"}]])
+                            return
+                        end
+
+                        if body.messages[1].content == "write an SQL query to 
get all rows from student table" then
+                            ngx.print("SELECT * FROM STUDENTS")
+                            return
+                        end
+
+                        ngx.status = 200
+                        ngx.say([[
+{
+  "choices": [
+    {
+      "finish_reason": "stop",
+      "index": 0,
+      "message": { "content": "1 + 1 = 2.", "role": "assistant" }
+    }
+  ],
+  "created": 1723780938,
+  "id": "chatcmpl-9wiSIg5LYrrpxwsr2PubSQnbtod1P",
+  "model": "gpt-4o-2024-05-13",
+  "object": "chat.completion",
+  "system_fingerprint": "fp_abc28019ad",
+  "usage": { "completion_tokens": 5, "prompt_tokens": 8, "total_tokens": 10 }
+}
+                        ]])
+                        return
+                    end
+
+
+                    ngx.status = 503
+                    ngx.say("reached the end of the test suite")
+                }
+            }
+
+            location /random {
+                content_by_lua_block {
+                    ngx.say("path override works")
+                }
+            }
+        }
+_EOC_
+
+    $block->set_value("http_config", $http_config);
+});
+
+run_tests();
+
+__DATA__
+
+=== TEST 1: sanity
+--- config
+    location /t {
+        content_by_lua_block {
+            local configs = {
+                {
+                    time_window = 60,
+                },
+                {
+                    limit = 30,
+                },
+                {
+                    limit = 30,
+                    time_window = 60,
+                    rejected_code = 199,
+                },
+                {
+                    limit = 30,
+                    time_window = 60,
+                    limit_strategy = "invalid",
+                },
+                {
+                    limit = 30,
+                    time_window = 60,
+                    instances = {
+                        {
+                            name = "instance1",
+                            limit = 30,
+                            time_window = 60,
+                        },
+                        {
+                            limit = 30,
+                            time_window = 60,
+                        }
+                    },
+                },
+                {
+                    time_window = 60,
+                    instances = {
+                        {
+                            name = "instance1",
+                            limit = 30,
+                            time_window = 60,
+                        }
+                    },
+                },
+                {
+                    limit = 30,
+                    time_window = 60,
+                    rejected_code = 403,
+                    rejected_msg = "rate limit exceeded",
+                    limit_strategy = "completion_tokens",
+                }
+            }
+            local core = require("apisix.core")
+            local plugin = require("apisix.plugins.ai-rate-limiting")
+            for _, config in ipairs(configs) do
+                local ok, err = plugin.check_schema(config)
+                if not ok then
+                    ngx.say(err)
+                else
+                    ngx.say("passed")
+                end
+            end
+            ngx.say("done")
+        }
+    }
+--- response_body
+property "limit" is required
+property "time_window" is required
+property "rejected_code" validation failed: expected 199 to be at least 200
+property "limit_strategy" validation failed: matches none of the enum values
+property "instances" validation failed: failed to validate item 2: property 
"name" is required
+property "limit" is required
+passed
+done
+
+
+
+=== TEST 2: set route 1, default limit_strategy: total_tokens
+--- config
+    location /t {
+        content_by_lua_block {
+            local t = require("lib.test_admin").test
+            local code, body = t('/apisix/admin/routes/1',
+                 ngx.HTTP_PUT,
+                 [[{
+                    "uri": "/ai",
+                    "plugins": {
+                        "ai-proxy": {
+                            "provider": "openai",
+                            "auth": {
+                                "header": {
+                                    "Authorization": "Bearer token"
+                                }
+                            },
+                            "options": {
+                                "model": "gpt-35-turbo-instruct",
+                                "max_tokens": 512,
+                                "temperature": 1.0
+                            },
+                            "override": {
+                                "endpoint": "http://localhost:16724"
+                            },
+                            "ssl_verify": false
+                        },
+                        "ai-rate-limiting": {
+                            "limit": 30,
+                            "time_window": 60
+                        }
+                    },
+                    "upstream": {
+                        "type": "roundrobin",
+                        "nodes": {
+                            "canbeanything.com": 1
+                        }
+                    }
+                }]]
+            )
+
+            if code >= 300 then
+                ngx.status = code
+            end
+            ngx.say(body)
+        }
+    }
+--- response_body
+passed
+
+
+
+=== TEST 3: reject the 4th request
+--- pipelined_requests eval
+[
+    "POST /ai\n" . "{ \"messages\": [ { \"role\": \"system\", \"content\": 
\"You are a mathematician\" }, { \"role\": \"user\", \"content\": \"What is 
1+1?\"} ] }",
+    "POST /ai\n" . "{ \"messages\": [ { \"role\": \"system\", \"content\": 
\"You are a mathematician\" }, { \"role\": \"user\", \"content\": \"What is 
1+1?\"} ] }",
+    "POST /ai\n" . "{ \"messages\": [ { \"role\": \"system\", \"content\": 
\"You are a mathematician\" }, { \"role\": \"user\", \"content\": \"What is 
1+1?\"} ] }",
+    "POST /ai\n" . "{ \"messages\": [ { \"role\": \"system\", \"content\": 
\"You are a mathematician\" }, { \"role\": \"user\", \"content\": \"What is 
1+1?\"} ] }",
+]
+--- more_headers
+Authorization: Bearer token
+--- error_code eval
+[200, 200, 200, 503]
+
+
+
+=== TEST 4: set rejected_code to 403, rejected_msg to "rate limit exceeded"
+--- config
+    location /t {
+        content_by_lua_block {
+            local t = require("lib.test_admin").test
+            local code, body = t('/apisix/admin/routes/1',
+                 ngx.HTTP_PUT,
+                 [[{
+                    "uri": "/ai",
+                    "plugins": {
+                        "ai-proxy": {
+                            "provider": "openai",
+                            "auth": {
+                                "header": {
+                                    "Authorization": "Bearer token"
+                                }
+                            },
+                            "options": {
+                                "model": "gpt-35-turbo-instruct",
+                                "max_tokens": 512,
+                                "temperature": 1.0
+                            },
+                            "override": {
+                                "endpoint": "http://localhost:16724";
+                            },
+                            "ssl_verify": false
+                        },
+                        "ai-rate-limiting": {
+                            "limit": 30,
+                            "time_window": 60,
+                            "rejected_code": 403,
+                            "rejected_msg": "rate limit exceeded"
+                        }
+                    },
+                    "upstream": {
+                        "type": "roundrobin",
+                        "nodes": {
+                            "canbeanything.com": 1
+                        }
+                    }
+                }]]
+            )
+
+            if code >= 300 then
+                ngx.status = code
+            end
+            ngx.say(body)
+        }
+    }
+--- response_body
+passed
+
+
+
+=== TEST 5: check code and message
+--- pipelined_requests eval
+[
+    "POST /ai\n" . "{ \"messages\": [ { \"role\": \"system\", \"content\": 
\"You are a mathematician\" }, { \"role\": \"user\", \"content\": \"What is 
1+1?\"} ] }",
+    "POST /ai\n" . "{ \"messages\": [ { \"role\": \"system\", \"content\": 
\"You are a mathematician\" }, { \"role\": \"user\", \"content\": \"What is 
1+1?\"} ] }",
+    "POST /ai\n" . "{ \"messages\": [ { \"role\": \"system\", \"content\": 
\"You are a mathematician\" }, { \"role\": \"user\", \"content\": \"What is 
1+1?\"} ] }",
+    "POST /ai\n" . "{ \"messages\": [ { \"role\": \"system\", \"content\": 
\"You are a mathematician\" }, { \"role\": \"user\", \"content\": \"What is 
1+1?\"} ] }",
+]
+--- more_headers
+Authorization: Bearer token
+--- error_code eval
+[200, 200, 200, 403]
+--- response_body eval
+[
+    qr/\{ "content": "1 \+ 1 = 2\.", "role": "assistant" \}/,
+    qr/\{ "content": "1 \+ 1 = 2\.", "role": "assistant" \}/,
+    qr/\{ "content": "1 \+ 1 = 2\.", "role": "assistant" \}/,
+    qr/\{"error_msg":"rate limit exceeded"\}/,
+]
+
+
+
+=== TEST 6: check rate limit headers
+--- request
+POST /ai
+{ "messages": [ { "role": "system", "content": "You are a mathematician" }, { 
"role": "user", "content": "What is 1+1?"} ] }
+--- more_headers
+Authorization: Bearer token
+--- response_headers
+X-AI-RateLimit-Limit-ai-proxy: 30
+X-AI-RateLimit-Remaining-ai-proxy: 29
+X-AI-RateLimit-Reset-ai-proxy: 60
+
+
+
+=== TEST 7: check rate limit headers after 4 requests
+--- pipelined_requests eval
+[
+    "POST /ai\n" . "{ \"messages\": [ { \"role\": \"system\", \"content\": 
\"You are a mathematician\" }, { \"role\": \"user\", \"content\": \"What is 
1+1?\"} ] }",
+    "POST /ai\n" . "{ \"messages\": [ { \"role\": \"system\", \"content\": 
\"You are a mathematician\" }, { \"role\": \"user\", \"content\": \"What is 
1+1?\"} ] }",
+    "POST /ai\n" . "{ \"messages\": [ { \"role\": \"system\", \"content\": 
\"You are a mathematician\" }, { \"role\": \"user\", \"content\": \"What is 
1+1?\"} ] }",
+    "POST /ai\n" . "{ \"messages\": [ { \"role\": \"system\", \"content\": 
\"You are a mathematician\" }, { \"role\": \"user\", \"content\": \"What is 
1+1?\"} ] }",
+]
+--- more_headers
+Authorization: Bearer token
+--- error_code eval
+[200, 200, 200, 403]
+--- response_headers eval
+[
+    "X-AI-RateLimit-Remaining-ai-proxy: 29",
+    "X-AI-RateLimit-Remaining-ai-proxy: 19",
+    "X-AI-RateLimit-Remaining-ai-proxy: 9",
+    "X-AI-RateLimit-Remaining-ai-proxy: 0",
+]
+
+
+
+=== TEST 8: set route2 with limit_strategy: completion_tokens
+--- config
+    location /t {
+        content_by_lua_block {
+            local t = require("lib.test_admin").test
+            local code, body = t('/apisix/admin/routes/2',
+                 ngx.HTTP_PUT,
+                 [[{
+                    "uri": "/ai2",
+                    "plugins": {
+                        "ai-proxy": {
+                            "provider": "openai",
+                            "auth": {
+                                "header": {
+                                    "Authorization": "Bearer token"
+                                }
+                            },
+                            "options": {
+                                "model": "gpt-35-turbo-instruct",
+                                "max_tokens": 512,
+                                "temperature": 1.0
+                            },
+                            "override": {
+                                "endpoint": "http://localhost:16724";
+                            },
+                            "ssl_verify": false
+                        },
+                        "ai-rate-limiting": {
+                            "limit": 20,
+                            "time_window": 45,
+                            "limit_strategy": "completion_tokens"
+                        }
+                    },
+                    "upstream": {
+                        "type": "roundrobin",
+                        "nodes": {
+                            "canbeanything.com": 1
+                        }
+                    }
+                }]]
+            )
+
+            if code >= 300 then
+                ngx.status = code
+            end
+            ngx.say(body)
+        }
+    }
+--- response_body
+passed
+
+
+
+=== TEST 9: reject the 5th request
+--- pipelined_requests eval
+[
+    "POST /ai2\n" . "{ \"messages\": [ { \"role\": \"system\", \"content\": 
\"You are a mathematician\" }, { \"role\": \"user\", \"content\": \"What is 
1+1?\"} ] }",
+    "POST /ai2\n" . "{ \"messages\": [ { \"role\": \"system\", \"content\": 
\"You are a mathematician\" }, { \"role\": \"user\", \"content\": \"What is 
1+1?\"} ] }",
+    "POST /ai2\n" . "{ \"messages\": [ { \"role\": \"system\", \"content\": 
\"You are a mathematician\" }, { \"role\": \"user\", \"content\": \"What is 
1+1?\"} ] }",
+    "POST /ai2\n" . "{ \"messages\": [ { \"role\": \"system\", \"content\": 
\"You are a mathematician\" }, { \"role\": \"user\", \"content\": \"What is 
1+1?\"} ] }",
+    "POST /ai2\n" . "{ \"messages\": [ { \"role\": \"system\", \"content\": 
\"You are a mathematician\" }, { \"role\": \"user\", \"content\": \"What is 
1+1?\"} ] }",
+]
+--- more_headers
+Authorization: Bearer token
+--- error_code eval
+[200, 200, 200, 200, 503]
+
+
+
+=== TEST 10: check rate limit headers
+--- request
+POST /ai2
+{ "messages": [ { "role": "system", "content": "You are a mathematician" }, { 
"role": "user", "content": "What is 1+1?"} ] }
+--- more_headers
+Authorization: Bearer token
+--- response_headers
+X-AI-RateLimit-Limit-ai-proxy: 20
+X-AI-RateLimit-Remaining-ai-proxy: 19
+X-AI-RateLimit-Reset-ai-proxy: 45
+
+
+
+=== TEST 11: multi-request
+--- pipelined_requests eval
+[
+    "POST /ai2\n" . "{ \"messages\": [ { \"role\": \"system\", \"content\": 
\"You are a mathematician\" }, { \"role\": \"user\", \"content\": \"What is 
1+1?\"} ] }",
+    "POST /ai2\n" . "{ \"messages\": [ { \"role\": \"system\", \"content\": 
\"You are a mathematician\" }, { \"role\": \"user\", \"content\": \"What is 
1+1?\"} ] }",
+    "POST /ai2\n" . "{ \"messages\": [ { \"role\": \"system\", \"content\": 
\"You are a mathematician\" }, { \"role\": \"user\", \"content\": \"What is 
1+1?\"} ] }",
+    "POST /ai2\n" . "{ \"messages\": [ { \"role\": \"system\", \"content\": 
\"You are a mathematician\" }, { \"role\": \"user\", \"content\": \"What is 
1+1?\"} ] }",
+    "POST /ai2\n" . "{ \"messages\": [ { \"role\": \"system\", \"content\": 
\"You are a mathematician\" }, { \"role\": \"user\", \"content\": \"What is 
1+1?\"} ] }",
+]
+--- more_headers
+Authorization: Bearer token
+--- error_code eval
+[200, 200, 200, 200, 503]
+--- response_headers eval
+[
+    "X-AI-RateLimit-Remaining-ai-proxy: 19",
+    "X-AI-RateLimit-Remaining-ai-proxy: 14",
+    "X-AI-RateLimit-Remaining-ai-proxy: 9",
+    "X-AI-RateLimit-Remaining-ai-proxy: 4",
+    "X-AI-RateLimit-Remaining-ai-proxy: 0",
+]
+
+
+
+=== TEST 12: request route 1 and route 2
+--- pipelined_requests eval
+[
+    "POST /ai\n" . "{ \"messages\": [ { \"role\": \"system\", \"content\": 
\"You are a mathematician\" }, { \"role\": \"user\", \"content\": \"What is 
1+1?\"} ] }",
+    "POST /ai\n" . "{ \"messages\": [ { \"role\": \"system\", \"content\": 
\"You are a mathematician\" }, { \"role\": \"user\", \"content\": \"What is 
1+1?\"} ] }",
+    "POST /ai\n" . "{ \"messages\": [ { \"role\": \"system\", \"content\": 
\"You are a mathematician\" }, { \"role\": \"user\", \"content\": \"What is 
1+1?\"} ] }",
+    "POST /ai2\n" . "{ \"messages\": [ { \"role\": \"system\", \"content\": 
\"You are a mathematician\" }, { \"role\": \"user\", \"content\": \"What is 
1+1?\"} ] }",
+    "POST /ai2\n" . "{ \"messages\": [ { \"role\": \"system\", \"content\": 
\"You are a mathematician\" }, { \"role\": \"user\", \"content\": \"What is 
1+1?\"} ] }",
+    "POST /ai2\n" . "{ \"messages\": [ { \"role\": \"system\", \"content\": 
\"You are a mathematician\" }, { \"role\": \"user\", \"content\": \"What is 
1+1?\"} ] }",
+    "POST /ai2\n" . "{ \"messages\": [ { \"role\": \"system\", \"content\": 
\"You are a mathematician\" }, { \"role\": \"user\", \"content\": \"What is 
1+1?\"} ] }",
+    "POST /ai\n" . "{ \"messages\": [ { \"role\": \"system\", \"content\": 
\"You are a mathematician\" }, { \"role\": \"user\", \"content\": \"What is 
1+1?\"} ] }",
+    "POST /ai2\n" . "{ \"messages\": [ { \"role\": \"system\", \"content\": 
\"You are a mathematician\" }, { \"role\": \"user\", \"content\": \"What is 
1+1?\"} ] }",
+]
+--- more_headers
+Authorization: Bearer token
+--- error_code eval
+[200, 200, 200, 200, 200, 200, 200, 403, 503]


Reply via email to