nic-6443 commented on code in PR #13191:
URL: https://github.com/apache/apisix/pull/13191#discussion_r3057789724


##########
apisix/plugins/ai-rate-limiting.lua:
##########
@@ -136,8 +149,42 @@ local limit_conf_cache = core.lrucache.new({
 })
 
 
+-- safe math functions allowed in cost expressions
+local expr_safe_env = {
+    math = math,
+    abs = math.abs,
+    ceil = math.ceil,
+    floor = math.floor,
+    max = math.max,
+    min = math.min,
+}
+
+local function compile_cost_expr(expr_str)
+    local fn_code = "return " .. expr_str
+    -- validate syntax by loading first
+    local fn, err = load(fn_code, "cost_expr", "t", expr_safe_env)
+    if not fn then
+        return nil, err

Review Comment:
   The expression is configured by admins via the plugin config, not by 
end-users or API callers. Admins already have full control over the gateway. A 
strict arithmetic-only validator would add significant complexity. This is 
consistent with the upstream implementation.



##########
apisix/plugins/ai-rate-limiting.lua:
##########
@@ -264,7 +311,51 @@ function _M.check_instance_status(conf, ctx, instance_name)
 end
 
 
+local function eval_cost_expr(conf_cost_expr, raw)
+    local fn_code = "return " .. conf_cost_expr

Review Comment:
   This is a valid optimization opportunity. The expression strings are 
typically very short and load() in LuaJIT is fast, so the per-request compile 
cost should be small. This is consistent with the upstream implementation. 
Caching the compiled function can be done as a follow-up optimization if 
profiling shows it matters.
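
   For reference, one possible shape of that follow-up, as a sketch only. `compiled_cache`, `get_compiled`, and `eval_cached` are hypothetical names, and the plain table used as a cache is unbounded, so a real implementation would likely want a bounded cache instead. The idea is to compile each expression once against a proxy environment and swap in the per-request usage table before each call. As in the diff above, numeric usage fields shadow the math helpers and unknown names default to 0.

```lua
local safe_env = { math = math, abs = math.abs, ceil = math.ceil,
                   floor = math.floor, max = math.max, min = math.min }

-- Hypothetical cache: expression string -> { fn = compiled chunk, usage = table }
local compiled_cache = {}

local function get_compiled(expr_str)
    local entry = compiled_cache[expr_str]
    if entry then
        return entry
    end
    entry = { usage = {} }
    -- Proxy environment: lookups go to the current usage table first.
    local env = setmetatable({}, {
        __index = function(_, k)
            local v = entry.usage[k]
            if type(v) == "number" then
                return v
            end
            v = safe_env[k]
            if v ~= nil then
                return v
            end
            return 0  -- missing or non-numeric variables default to 0
        end,
    })
    local code = "return " .. expr_str
    local fn, err
    if setfenv then
        -- Lua 5.1 / LuaJIT path
        fn, err = loadstring(code, "cost_expr")
        if fn then
            setfenv(fn, env)
        end
    else
        fn, err = load(code, "cost_expr", "t", env)
    end
    if not fn then
        return nil, err
    end
    entry.fn = fn
    compiled_cache[expr_str] = entry
    return entry
end

local function eval_cached(expr_str, usage)
    local entry, err = get_compiled(expr_str)
    if not entry then
        return nil, err
    end
    entry.usage = usage or {}  -- swap in this request's usage fields
    local ok, result = pcall(entry.fn)
    entry.usage = {}
    if not ok then
        return nil, result
    end
    return result
end

print(eval_cached("input_tokens + output_tokens",
                  { input_tokens = 50, output_tokens = 30 }))  -- 80
print(eval_cached("input_tokens + output_tokens", { input_tokens = 1 }))  -- 1
```

   The swap is only safe because the compiled chunk runs to completion inside `pcall` without yielding; that is an assumption of this sketch.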



##########
apisix/plugins/ai-rate-limiting.lua:
##########
@@ -264,7 +311,51 @@ function _M.check_instance_status(conf, ctx, instance_name)
 end
 
 
+local function eval_cost_expr(conf_cost_expr, raw)
+    local fn_code = "return " .. conf_cost_expr
+    -- build environment: safe math + usage variables (missing vars default to 0)
+    local env = setmetatable({}, {
+        __index = function(_, k)
+            local v = expr_safe_env[k]
+            if v ~= nil then
+                return v
+            end
+            return 0
+        end
+    })
+    for k, v in pairs(raw) do
+        if type(v) == "number" then
+            env[k] = v
+        end
+    end
+    local fn, err = load(fn_code, "cost_expr", "t", env)
+    if not fn then
+        return nil, "failed to compile cost_expr: " .. err
+    end
+    local ok, result = pcall(fn)
+    if not ok then
+        return nil, "failed to evaluate cost_expr: " .. result
+    end
+    if type(result) ~= "number" then
+        return nil, "cost_expr must return a number, got: " .. type(result)
+    end
+    return math_floor(result + 0.5)

Review Comment:
   Fixed: the function now rejects NaN and infinite results and clamps negative 
results to 0. Added a test case (TEST 13) to verify.
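
   A sketch of that kind of post-processing, for illustration only. `finalize_cost` and its error messages are hypothetical and may differ from the actual change in the PR; only the rounding step `floor(result + 0.5)` is taken from the diff above.

```lua
-- Hypothetical post-processing of an expression result.
local function finalize_cost(result)
    if type(result) ~= "number" then
        return nil, "cost_expr must return a number, got: " .. type(result)
    end
    -- NaN is the only value that is not equal to itself
    if result ~= result then
        return nil, "cost_expr returned NaN"
    end
    if result == math.huge or result == -math.huge then
        return nil, "cost_expr returned an infinite value"
    end
    -- negative costs are clamped rather than rejected
    if result < 0 then
        return 0
    end
    -- round half up to the nearest integer
    return math.floor(result + 0.5)
end

print(finalize_cost(2.5))    -- 3
print(finalize_cost(-150))   -- 0
print(finalize_cost(0 / 0))  -- nil, plus the NaN error message
print(finalize_cost(1 / 0))  -- nil, plus the infinity error message
```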



##########
apisix/plugins/ai-rate-limiting.lua:
##########
@@ -61,10 +65,19 @@ local schema = {
         show_limit_quota_header = {type = "boolean", default = true},
         limit_strategy = {
             type = "string",
-            enum = {"total_tokens", "prompt_tokens", "completion_tokens"},
+            enum = {"total_tokens", "prompt_tokens", "completion_tokens", "expression"},
             default = "total_tokens",
             description = "The strategy to limit the tokens"
         },
+        cost_expr = {
+            type = "string",
+            minLength = 1,
+            description = "Lua arithmetic expression for dynamic token cost calculation. "
+                .. "Variables are injected from the LLM API raw usage response fields. "
+                .. "Missing variables default to 0. "
+                .. "Only valid when limit_strategy is 'expression'. "
+                .. "Example: input_tokens + cache_creation_input_tokens + output_tokens",
+        },

Review Comment:
   Setting cost_expr when limit_strategy is not "expression" is harmless: the 
field is ignored. Rejecting it would make config management harder (e.g., 
switching strategies would require removing and re-adding the field). This 
follows the be-liberal-in-what-you-accept principle.
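
   The intended behaviour can be sketched as a small check layered on top of the JSON-schema validation. `check_cost_expr` is a hypothetical name and this is not the plugin's actual `check_schema`; it only illustrates "required for the expression strategy, tolerated otherwise".

```lua
-- Hypothetical extra check applied after JSON-schema validation.
local function check_cost_expr(conf)
    if conf.limit_strategy ~= "expression" then
        -- cost_expr, if present, is tolerated and simply ignored
        return true
    end
    if type(conf.cost_expr) ~= "string" or conf.cost_expr == "" then
        return false, "cost_expr is required when limit_strategy is 'expression'"
    end
    return true
end

-- Switching strategies does not require removing the field.
print(check_cost_expr({ limit_strategy = "total_tokens",
                        cost_expr = "input_tokens + output_tokens" }))  -- true
-- The expression strategy without cost_expr is rejected.
print(check_cost_expr({ limit_strategy = "expression" }))  -- false, plus the message
```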



##########
t/plugin/ai-rate-limiting-expression.t:
##########
@@ -0,0 +1,551 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one or more
+# contributor license agreements.  See the NOTICE file distributed with
+# this work for additional information regarding copyright ownership.
+# The ASF licenses this file to You under the Apache License, Version 2.0
+# (the "License"); you may not use this file except in compliance with
+# the License.  You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+BEGIN {
+    $ENV{TEST_ENABLE_CONTROL_API_V1} = "0";
+}
+
+use t::APISIX 'no_plan';
+
+log_level("info");
+repeat_each(1);
+no_long_string();
+no_shuffle();
+no_root_location();
+
+add_block_preprocessor(sub {
+    my ($block) = @_;
+
+    if (!defined $block->request) {
+        $block->set_value("request", "GET /t");
+    }
+
+    my $http_config = $block->http_config // <<_EOC_;
+        server {
+            server_name anthropic;
+            listen 16725;
+
+            default_type 'application/json';
+
+            location /v1/messages {
+                content_by_lua_block {
+                    local json = require("cjson.safe")
+                    local ngx = ngx
+
+                    ngx.req.read_body()
+                    local body = ngx.req.get_body_data()
+                    body = json.decode(body)
+
+                    if not body or not body.messages then
+                        ngx.status = 400
+                        ngx.say('{"type":"error","error":{"type":"invalid_request_error","message":"missing messages"}}')
+                        return
+                    end
+
+                    local api_key = ngx.req.get_headers()["x-api-key"]
+                    if api_key ~= "test-key" then
+                        ngx.status = 401
+                        ngx.say('{"type":"error","error":{"type":"authentication_error","message":"invalid x-api-key"}}')
+                        return
+                    end
+
+                    if body.stream then
+                        ngx.header["Content-Type"] = "text/event-stream"
+
+                        -- message_start with input_tokens and cache tokens
+                        local message_start = json.encode({
+                            type = "message_start",
+                            message = {
+                                id = "msg_test123",
+                                type = "message",
+                                role = "assistant",
+                                model = body.model or "claude-sonnet-4-20250514",
+                                content = {},
+                                usage = {
+                                    input_tokens = 50,
+                                    output_tokens = 0,
+                                    cache_creation_input_tokens = 100,
+                                    cache_read_input_tokens = 200,
+                                },
+                            },
+                        })
+                        ngx.say("event: message_start")
+                        ngx.say("data: " .. message_start)
+                        ngx.say("")
+
+                        -- content_block_start
+                        ngx.say("event: content_block_start")
+                        ngx.say('data: {"type":"content_block_start","index":0,"content_block":{"type":"text","text":""}}')
+                        ngx.say("")
+
+                        -- content_block_delta
+                        ngx.say("event: content_block_delta")
+                        ngx.say('data: {"type":"content_block_delta","index":0,"delta":{"type":"text_delta","text":"Hello from Claude!"}}')
+                        ngx.say("")
+
+                        -- content_block_stop
+                        ngx.say("event: content_block_stop")
+                        ngx.say('data: {"type":"content_block_stop","index":0}')
+                        ngx.say("")
+
+                        -- message_delta with output_tokens
+                        local message_delta = json.encode({
+                            type = "message_delta",
+                            delta = { stop_reason = "end_turn" },
+                            usage = {
+                                output_tokens = 30,
+                            },
+                        })
+                        ngx.say("event: message_delta")
+                        ngx.say("data: " .. message_delta)
+                        ngx.say("")
+
+                        -- message_stop
+                        ngx.say("event: message_stop")
+                        ngx.say("data: {}")
+                        ngx.say("")
+                    else
+                        ngx.status = 200
+                        ngx.say(json.encode({
+                            id = "msg_test456",
+                            type = "message",
+                            role = "assistant",
+                            model = body.model or "claude-sonnet-4-20250514",
+                            content = {{
+                                type = "text",
+                                text = "Hello from Claude!",
+                            }},
+                            stop_reason = "end_turn",
+                            usage = {
+                                input_tokens = 50,
+                                output_tokens = 30,
+                                cache_creation_input_tokens = 100,
+                                cache_read_input_tokens = 200,
+                            },
+                        }))
+                    end
+                }
+            }
+        }
+_EOC_
+
+    $block->set_value("http_config", $http_config);
+});
+
+run_tests();
+
+__DATA__
+
+=== TEST 1: schema validation - expression strategy requires cost_expr
+--- config
+    location /t {
+        content_by_lua_block {
+            local plugin = require("apisix.plugins.ai-rate-limiting")
+            local configs = {
+                -- expression without cost_expr
+                {
+                    limit = 100,
+                    time_window = 60,
+                    limit_strategy = "expression",
+                },
+                -- expression with empty cost_expr
+                {
+                    limit = 100,
+                    time_window = 60,
+                    limit_strategy = "expression",
+                    cost_expr = "",
+                },
+                -- expression with invalid cost_expr syntax
+                {
+                    limit = 100,
+                    time_window = 60,
+                    limit_strategy = "expression",
+                    cost_expr = "invalid $$$ syntax %%%",
+                },
+                -- valid expression
+                {
+                    limit = 100,
+                    time_window = 60,
+                    limit_strategy = "expression",
+                    cost_expr = "input_tokens + output_tokens",
+                },
+                -- valid complex expression
+                {
+                    limit = 100,
+                    time_window = 60,
+                    limit_strategy = "expression",
+                    cost_expr = "(input_tokens - cache_read_input_tokens) + cache_creation_input_tokens * 1.25 + output_tokens",
+                },

Review Comment:
   Added TEST 12/13: the expression input_tokens - cache_read_input_tokens 
yields -150 (50 - 200 with the mock usage values), which is clamped to 0. 
Remaining stays at 99 for both requests.
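
   The arithmetic can be checked by hand against the mock usage values returned by the test server in this file. This is plain Lua for illustration, not part of the test suite; it also evaluates the complex expression from the schema test above.

```lua
-- Usage values as returned by the mock /v1/messages endpoint in this test file.
local usage = {
    input_tokens = 50,
    output_tokens = 30,
    cache_creation_input_tokens = 100,
    cache_read_input_tokens = 200,
}

-- input_tokens - cache_read_input_tokens
local simple = usage.input_tokens - usage.cache_read_input_tokens
print(simple)               -- -150
print(math.max(simple, 0))  -- 0 after clamping

-- (input_tokens - cache_read_input_tokens)
--     + cache_creation_input_tokens * 1.25 + output_tokens
local complex = (usage.input_tokens - usage.cache_read_input_tokens)
    + usage.cache_creation_input_tokens * 1.25
    + usage.output_tokens
print(complex)  -- -150 + 125 + 30 = 5 (printed as 5.0 on Lua 5.3+)
```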



-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

