This is an automated email from the ASF dual-hosted git repository.
sunyi pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/apisix.git
The following commit(s) were added to refs/heads/master by this push:
new 78980127a refactor(ai-proxy): move read_response into ai_driver.request function (#12101)
78980127a is described below
commit 78980127a70db5734512c5508cd48f4f1fe2c51b
Author: litesun <[email protected]>
AuthorDate: Mon Apr 21 17:55:46 2025 +0800
refactor(ai-proxy): move read_response into ai_driver.request function (#12101)
---
apisix/plugins/ai-drivers/openai-base.lua | 177 +++++++++++++++++-------------
apisix/plugins/ai-proxy/base.lua | 12 +-
apisix/plugins/ai-proxy/schema.lua | 12 ++
3 files changed, 114 insertions(+), 87 deletions(-)
diff --git a/apisix/plugins/ai-drivers/openai-base.lua b/apisix/plugins/ai-drivers/openai-base.lua
index a4f061fe4..09134265e 100644
--- a/apisix/plugins/ai-drivers/openai-base.lua
+++ b/apisix/plugins/ai-drivers/openai-base.lua
@@ -35,6 +35,9 @@ local type = type
local ipairs = ipairs
local setmetatable = setmetatable
+local HTTP_INTERNAL_SERVER_ERROR = ngx.HTTP_INTERNAL_SERVER_ERROR
+local HTTP_GATEWAY_TIMEOUT = ngx.HTTP_GATEWAY_TIMEOUT
+
function _M.new(opts)
@@ -62,81 +65,19 @@ function _M.validate_request(ctx)
end
-function _M.request(self, conf, request_table, extra_opts)
- local httpc, err = http.new()
- if not httpc then
- return nil, "failed to create http client to send request to LLM server: " .. err
- end
- httpc:set_timeout(conf.timeout)
-
- local endpoint = extra_opts.endpoint
- local parsed_url
- if endpoint then
- parsed_url = url.parse(endpoint)
- end
-
- local ok, err = httpc:connect({
- scheme = parsed_url and parsed_url.scheme or "https",
- host = parsed_url and parsed_url.host or self.host,
- port = parsed_url and parsed_url.port or self.port,
- ssl_verify = conf.ssl_verify,
- ssl_server_name = parsed_url and parsed_url.host or self.host,
- pool_size = conf.keepalive and conf.keepalive_pool,
- })
-
- if not ok then
- return nil, "failed to connect to LLM server: " .. err
- end
-
- local query_params = extra_opts.query_params
-
- if type(parsed_url) == "table" and parsed_url.query and #parsed_url.query > 0 then
- local args_tab = core.string.decode_args(parsed_url.query)
- if type(args_tab) == "table" then
- core.table.merge(query_params, args_tab)
- end
+local function handle_error(err)
+ if core.string.find(err, "timeout") then
+ return HTTP_GATEWAY_TIMEOUT
end
-
- local path = (parsed_url and parsed_url.path or self.path)
-
- local headers = extra_opts.headers
- headers["Content-Type"] = "application/json"
- local params = {
- method = "POST",
- headers = headers,
- keepalive = conf.keepalive,
- ssl_verify = conf.ssl_verify,
- path = path,
- query = query_params
- }
-
- if extra_opts.model_options then
- for opt, val in pairs(extra_opts.model_options) do
- request_table[opt] = val
- end
- end
-
- local req_json, err = core.json.encode(request_table)
- if not req_json then
- return nil, err
- end
-
- params.body = req_json
-
- local res, err = httpc:request(params)
- if not res then
- return nil, err
- end
-
- return res, nil
+ return HTTP_INTERNAL_SERVER_ERROR
end
-function _M.read_response(ctx, res)
+local function read_response(ctx, res)
local body_reader = res.body_reader
if not body_reader then
core.log.warn("AI service sent no response body")
- return 500
+ return HTTP_INTERNAL_SERVER_ERROR
end
local content_type = res.headers["Content-Type"]
@@ -147,10 +88,7 @@ function _M.read_response(ctx, res)
local chunk, err = body_reader() -- will read chunk by chunk
if err then
core.log.warn("failed to read response chunk: ", err)
- if core.string.find(err, "timeout") then
- return 504
- end
- return 500
+ return handle_error(err)
end
if not chunk then
return
@@ -206,10 +144,7 @@ function _M.read_response(ctx, res)
local raw_res_body, err = res:read_body()
if not raw_res_body then
core.log.warn("failed to read response body: ", err)
- if core.string.find(err, "timeout") then
- return 504
- end
- return 500
+ return handle_error(err)
end
local res_body, err = core.json.decode(raw_res_body)
if err then
@@ -227,4 +162,94 @@ function _M.read_response(ctx, res)
end
+function _M.request(self, ctx, conf, request_table, extra_opts)
+ local httpc, err = http.new()
+ if not httpc then
+ core.log.error("failed to create http client to send request to LLM server: ", err)
+ return HTTP_INTERNAL_SERVER_ERROR
+ end
+ httpc:set_timeout(conf.timeout)
+
+ local endpoint = extra_opts.endpoint
+ local parsed_url
+ if endpoint then
+ parsed_url = url.parse(endpoint)
+ end
+
+ local scheme = parsed_url and parsed_url.scheme or "https"
+ local host = parsed_url and parsed_url.host or self.host
+ local port = parsed_url and parsed_url.port
+ if not port then
+ if scheme == "https" then
+ port = 443
+ else
+ port = 80
+ end
+ end
+ local ok, err = httpc:connect({
+ scheme = scheme,
+ host = host,
+ port = port,
+ ssl_verify = conf.ssl_verify,
+ ssl_server_name = parsed_url and parsed_url.host or self.host,
+ })
+
+ if not ok then
+ core.log.warn("failed to connect to LLM server: ", err)
+ return handle_error(err)
+ end
+
+ local query_params = extra_opts.query_params
+
+ if type(parsed_url) == "table" and parsed_url.query and #parsed_url.query > 0 then
+ local args_tab = core.string.decode_args(parsed_url.query)
+ if type(args_tab) == "table" then
+ core.table.merge(query_params, args_tab)
+ end
+ end
+
+ local path = (parsed_url and parsed_url.path or self.path)
+
+ local headers = extra_opts.headers
+ headers["Content-Type"] = "application/json"
+ local params = {
+ method = "POST",
+ headers = headers,
+ ssl_verify = conf.ssl_verify,
+ path = path,
+ query = query_params
+ }
+
+ if extra_opts.model_options then
+ for opt, val in pairs(extra_opts.model_options) do
+ request_table[opt] = val
+ end
+ end
+
+ local req_json, err = core.json.encode(request_table)
+ if not req_json then
+ return nil, err
+ end
+
+ params.body = req_json
+
+ local res, err = httpc:request(params)
+ if not res then
+ core.log.warn("failed to send request to LLM server: ", err)
+ return handle_error(err)
+ end
+
+ local code, body = read_response(ctx, res)
+
+ if conf.keepalive then
+ local ok, err = httpc:set_keepalive(conf.keepalive_timeout, conf.keepalive_pool)
+ if not ok then
+ core.log.warn("failed to keepalive connection: ", err)
+ end
+ end
+
+ return code, body
+end
+
+
return _M
diff --git a/apisix/plugins/ai-proxy/base.lua b/apisix/plugins/ai-proxy/base.lua
index 73c683c0a..907626039 100644
--- a/apisix/plugins/ai-proxy/base.lua
+++ b/apisix/plugins/ai-proxy/base.lua
@@ -18,7 +18,6 @@
local core = require("apisix.core")
local require = require
local bad_request = ngx.HTTP_BAD_REQUEST
-local internal_server_error = ngx.HTTP_INTERNAL_SERVER_ERROR
local _M = {}
@@ -44,16 +43,7 @@ function _M.before_proxy(conf, ctx)
}
end
- local res, err = ai_driver:request(conf, request_body, extra_opts)
- if not res then
- core.log.warn("failed to send request to AI service: ", err)
- if core.string.find(err, "timeout") then
- return 504
- end
- return internal_server_error
- end
-
- return ai_driver.read_response(ctx, res)
+ return ai_driver:request(ctx, conf, request_body, extra_opts)
end
diff --git a/apisix/plugins/ai-proxy/schema.lua b/apisix/plugins/ai-proxy/schema.lua
index 1bd44da04..d0e8f23cb 100644
--- a/apisix/plugins/ai-proxy/schema.lua
+++ b/apisix/plugins/ai-proxy/schema.lua
@@ -108,6 +108,12 @@ _M.ai_proxy_schema = {
description = "timeout in milliseconds",
},
keepalive = {type = "boolean", default = true},
+ keepalive_timeout = {
+ type = "integer",
+ minimum = 1000,
+ default = 60000,
+ description = "keepalive timeout in milliseconds",
+ },
keepalive_pool = {type = "integer", minimum = 1, default = 30},
ssl_verify = {type = "boolean", default = true },
override = {
@@ -164,6 +170,12 @@ _M.ai_proxy_multi_schema = {
description = "timeout in milliseconds",
},
keepalive = {type = "boolean", default = true},
+ keepalive_timeout = {
+ type = "integer",
+ minimum = 1000,
+ default = 60000,
+ description = "keepalive timeout in milliseconds",
+ },
keepalive_pool = {type = "integer", minimum = 1, default = 30},
ssl_verify = {type = "boolean", default = true },
},