This is an automated email from the ASF dual-hosted git repository.

sunyi pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/apisix.git


The following commit(s) were added to refs/heads/master by this push:
     new 78980127a refactor(ai-proxy): move read_response into ai_driver.request function (#12101)
78980127a is described below

commit 78980127a70db5734512c5508cd48f4f1fe2c51b
Author: litesun <[email protected]>
AuthorDate: Mon Apr 21 17:55:46 2025 +0800

    refactor(ai-proxy): move read_response into ai_driver.request function (#12101)
---
 apisix/plugins/ai-drivers/openai-base.lua | 177 +++++++++++++++++-------------
 apisix/plugins/ai-proxy/base.lua          |  12 +-
 apisix/plugins/ai-proxy/schema.lua        |  12 ++
 3 files changed, 114 insertions(+), 87 deletions(-)

diff --git a/apisix/plugins/ai-drivers/openai-base.lua b/apisix/plugins/ai-drivers/openai-base.lua
index a4f061fe4..09134265e 100644
--- a/apisix/plugins/ai-drivers/openai-base.lua
+++ b/apisix/plugins/ai-drivers/openai-base.lua
@@ -35,6 +35,9 @@ local type  = type
 local ipairs = ipairs
 local setmetatable = setmetatable
 
+local HTTP_INTERNAL_SERVER_ERROR = ngx.HTTP_INTERNAL_SERVER_ERROR
+local HTTP_GATEWAY_TIMEOUT = ngx.HTTP_GATEWAY_TIMEOUT
+
 
 function _M.new(opts)
 
@@ -62,81 +65,20 @@ function _M.validate_request(ctx)
 end
 
 
-function _M.request(self, conf, request_table, extra_opts)
-    local httpc, err = http.new()
-    if not httpc then
-        return nil, "failed to create http client to send request to LLM server: " .. err
-    end
-    httpc:set_timeout(conf.timeout)
-
-    local endpoint = extra_opts.endpoint
-    local parsed_url
-    if endpoint then
-        parsed_url = url.parse(endpoint)
-    end
-
-    local ok, err = httpc:connect({
-        scheme = parsed_url and parsed_url.scheme or "https",
-        host = parsed_url and parsed_url.host or self.host,
-        port = parsed_url and parsed_url.port or self.port,
-        ssl_verify = conf.ssl_verify,
-        ssl_server_name = parsed_url and parsed_url.host or self.host,
-        pool_size = conf.keepalive and conf.keepalive_pool,
-    })
-
-    if not ok then
-        return nil, "failed to connect to LLM server: " .. err
-    end
-
-    local query_params = extra_opts.query_params
-
-    if type(parsed_url) == "table" and parsed_url.query and #parsed_url.query > 0 then
-        local args_tab = core.string.decode_args(parsed_url.query)
-        if type(args_tab) == "table" then
-            core.table.merge(query_params, args_tab)
-        end
+-- Map a transport error message to a status code: 504 on timeouts, 500 otherwise.
+local function handle_error(err)
+    if core.string.find(err, "timeout") then
+        return HTTP_GATEWAY_TIMEOUT
     end
-
-    local path = (parsed_url and parsed_url.path or self.path)
-
-    local headers = extra_opts.headers
-    headers["Content-Type"] = "application/json"
-    local params = {
-        method = "POST",
-        headers = headers,
-        keepalive = conf.keepalive,
-        ssl_verify = conf.ssl_verify,
-        path = path,
-        query = query_params
-    }
-
-    if extra_opts.model_options then
-        for opt, val in pairs(extra_opts.model_options) do
-            request_table[opt] = val
-        end
-    end
-
-    local req_json, err = core.json.encode(request_table)
-    if not req_json then
-        return nil, err
-    end
-
-    params.body = req_json
-
-    local res, err = httpc:request(params)
-    if not res then
-        return nil, err
-    end
-
-    return res, nil
+    return HTTP_INTERNAL_SERVER_ERROR
 end
 
 
-function _M.read_response(ctx, res)
+local function read_response(ctx, res)
     local body_reader = res.body_reader
     if not body_reader then
         core.log.warn("AI service sent no response body")
-        return 500
+        return HTTP_INTERNAL_SERVER_ERROR
     end
 
     local content_type = res.headers["Content-Type"]
@@ -147,10 +89,7 @@ function _M.read_response(ctx, res)
             local chunk, err = body_reader() -- will read chunk by chunk
             if err then
                 core.log.warn("failed to read response chunk: ", err)
-                if core.string.find(err, "timeout") then
-                    return 504
-                end
-                return 500
+                return handle_error(err)
             end
             if not chunk then
                 return
@@ -206,10 +145,7 @@ function _M.read_response(ctx, res)
     local raw_res_body, err = res:read_body()
     if not raw_res_body then
         core.log.warn("failed to read response body: ", err)
-        if core.string.find(err, "timeout") then
-            return 504
-        end
-        return 500
+        return handle_error(err)
     end
     local res_body, err = core.json.decode(raw_res_body)
     if err then
@@ -227,4 +163,107 @@ function _M.read_response(ctx, res)
 end
 
 
+-- Send `request_table` as a JSON POST to the LLM service and relay the
+-- response through read_response().
+-- When the body cannot be encoded or the request cannot be sent, returns
+-- an HTTP status code (504 for timeouts, 500 otherwise); otherwise
+-- returns whatever read_response() returns.
+function _M.request(self, ctx, conf, request_table, extra_opts)
+    if extra_opts.model_options then
+        for opt, val in pairs(extra_opts.model_options) do
+            request_table[opt] = val
+        end
+    end
+
+    -- encode before connecting so a bad body never opens a connection
+    local req_json, err = core.json.encode(request_table)
+    if not req_json then
+        core.log.error("failed to encode request body for LLM server: ", err)
+        return HTTP_INTERNAL_SERVER_ERROR
+    end
+
+    local httpc, err = http.new()
+    if not httpc then
+        core.log.error("failed to create http client to send request to LLM server: ", err)
+        return HTTP_INTERNAL_SERVER_ERROR
+    end
+    httpc:set_timeout(conf.timeout)
+
+    local endpoint = extra_opts.endpoint
+    local parsed_url
+    if endpoint then
+        parsed_url = url.parse(endpoint)
+    end
+
+    local scheme = parsed_url and parsed_url.scheme or "https"
+    local host = parsed_url and parsed_url.host or self.host
+    local port = parsed_url and parsed_url.port
+    if not port then
+        if not parsed_url and self.port then
+            -- no endpoint override: keep the driver's own port
+            port = self.port
+        elseif scheme == "https" then
+            port = 443
+        else
+            port = 80
+        end
+    end
+    local ok, err = httpc:connect({
+        scheme = scheme,
+        host = host,
+        port = port,
+        ssl_verify = conf.ssl_verify,
+        ssl_server_name = host,
+    })
+
+    if not ok then
+        core.log.warn("failed to connect to LLM server: ", err)
+        return handle_error(err)
+    end
+
+    local query_params = extra_opts.query_params
+
+    if type(parsed_url) == "table" and parsed_url.query and #parsed_url.query > 0 then
+        local args_tab = core.string.decode_args(parsed_url.query)
+        if type(args_tab) == "table" then
+            core.table.merge(query_params, args_tab)
+        end
+    end
+
+    local path = (parsed_url and parsed_url.path or self.path)
+
+    local headers = extra_opts.headers
+    headers["Content-Type"] = "application/json"
+    local params = {
+        method = "POST",
+        headers = headers,
+        ssl_verify = conf.ssl_verify,
+        path = path,
+        query = query_params,
+        body = req_json,
+    }
+
+    local res, err = httpc:request(params)
+    if not res then
+        core.log.warn("failed to send request to LLM server: ", err)
+        httpc:close()
+        return handle_error(err)
+    end
+
+    local code, body = read_response(ctx, res)
+
+    if conf.keepalive then
+        local ok, err = httpc:set_keepalive(conf.keepalive_timeout, conf.keepalive_pool)
+        if not ok then
+            core.log.warn("failed to keepalive connection: ", err)
+        end
+    else
+        -- keepalive disabled: release the connection right away
+        httpc:close()
+    end
+
+    return code, body
+end
+
+
 return _M
diff --git a/apisix/plugins/ai-proxy/base.lua b/apisix/plugins/ai-proxy/base.lua
index 73c683c0a..907626039 100644
--- a/apisix/plugins/ai-proxy/base.lua
+++ b/apisix/plugins/ai-proxy/base.lua
@@ -18,7 +18,6 @@
 local core = require("apisix.core")
 local require = require
 local bad_request = ngx.HTTP_BAD_REQUEST
-local internal_server_error = ngx.HTTP_INTERNAL_SERVER_ERROR
 
 local _M = {}
 
@@ -44,16 +43,9 @@ function _M.before_proxy(conf, ctx)
         }
     end
 
-    local res, err = ai_driver:request(conf, request_body, extra_opts)
-    if not res then
-        core.log.warn("failed to send request to AI service: ", err)
-        if core.string.find(err, "timeout") then
-            return 504
-        end
-        return internal_server_error
-    end
-
-    return ai_driver.read_response(ctx, res)
+    -- the driver sends the request and relays the response itself,
+    -- returning an HTTP status code on transport failures
+    return ai_driver:request(ctx, conf, request_body, extra_opts)
 end
 
 
diff --git a/apisix/plugins/ai-proxy/schema.lua b/apisix/plugins/ai-proxy/schema.lua
index 1bd44da04..d0e8f23cb 100644
--- a/apisix/plugins/ai-proxy/schema.lua
+++ b/apisix/plugins/ai-proxy/schema.lua
@@ -108,6 +108,12 @@ _M.ai_proxy_schema = {
             description = "timeout in milliseconds",
         },
         keepalive = {type = "boolean", default = true},
+        keepalive_timeout = {
+            type = "integer",
+            minimum = 1000,
+            default = 60000,
+            description = "keepalive timeout in milliseconds",
+        },
         keepalive_pool = {type = "integer", minimum = 1, default = 30},
         ssl_verify = {type = "boolean", default = true },
         override = {
@@ -164,6 +170,12 @@ _M.ai_proxy_multi_schema = {
             description = "timeout in milliseconds",
         },
         keepalive = {type = "boolean", default = true},
+        keepalive_timeout = {
+            type = "integer",
+            minimum = 1000,
+            default = 60000,
+            description = "keepalive timeout in milliseconds",
+        },
         keepalive_pool = {type = "integer", minimum = 1, default = 30},
         ssl_verify = {type = "boolean", default = true },
     },

Reply via email to