This is an automated email from the ASF dual-hosted git repository.
chengpan pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/kyuubi.git
The following commit(s) were added to refs/heads/master by this push:
new 067c6010a [KYUUBI #4554] [CHAT] Code improvement in ChatGPTProvider
067c6010a is described below
commit 067c6010a668c1ed9e83a991e589fb886304c959
Author: liangbowen <[email protected]>
AuthorDate: Sat Mar 18 15:34:16 2023 +0800
[KYUUBI #4554] [CHAT] Code improvement in ChatGPTProvider
### _Why are the changes needed?_
- set the `Authorization` header as a default header when constructing the HTTP client,
instead of adding it to each request
- handle the response's status code in Scala style (pattern matching on the status code)
- convert the config's Long timeout values to Int with `.intValue` instead of
`asInstanceOf` casting
- rename the variable `responseEntity` to `response`
### _How was this patch tested?_
- [ ] Add some test cases that check the changes thoroughly, including
negative and positive cases if possible
- [ ] Add screenshots for manual tests if appropriate
- [ ] [Run test](https://kyuubi.readthedocs.io/en/master/develop_tools/testing.html#running-tests)
locally before making a pull request
Closes #4554 from bowenliang123/chatgpt-http.
Closes #4554
114484a4d [liangbowen] httpclient improvement in ChatGPTProvider
Authored-by: liangbowen <[email protected]>
Signed-off-by: Cheng Pan <[email protected]>
---
.../engine/chat/provider/ChatGPTProvider.scala | 42 ++++++++++++----------
1 file changed, 24 insertions(+), 18 deletions(-)
diff --git
a/externals/kyuubi-chat-engine/src/main/scala/org/apache/kyuubi/engine/chat/provider/ChatGPTProvider.scala
b/externals/kyuubi-chat-engine/src/main/scala/org/apache/kyuubi/engine/chat/provider/ChatGPTProvider.scala
index 28ef36bb4..a4cdb7c94 100644
---
a/externals/kyuubi-chat-engine/src/main/scala/org/apache/kyuubi/engine/chat/provider/ChatGPTProvider.scala
+++
b/externals/kyuubi-chat-engine/src/main/scala/org/apache/kyuubi/engine/chat/provider/ChatGPTProvider.scala
@@ -20,12 +20,15 @@ package org.apache.kyuubi.engine.chat.provider
import java.util
import java.util.concurrent.TimeUnit
+import scala.collection.JavaConverters._
+
import com.google.common.cache.{CacheBuilder, CacheLoader, LoadingCache}
-import org.apache.http.{HttpHost, HttpStatus}
+import org.apache.http.{HttpHeaders, HttpHost, HttpStatus}
import org.apache.http.client.config.RequestConfig
import org.apache.http.client.methods.HttpPost
import org.apache.http.entity.{ContentType, StringEntity}
import org.apache.http.impl.client.{CloseableHttpClient, HttpClientBuilder}
+import org.apache.http.message.BasicHeader
import org.apache.http.util.EntityUtils
import org.apache.kyuubi.config.KyuubiConf
@@ -39,11 +42,16 @@ class ChatGPTProvider(conf: KyuubiConf) extends
ChatProvider {
s"which could be got at https://platform.openai.com/account/api-keys")
}
- private val httpClient: CloseableHttpClient =
HttpClientBuilder.create().build()
+ private val httpClient: CloseableHttpClient = {
+ HttpClientBuilder.create()
+ .setDefaultHeaders(List(
+ new BasicHeader(HttpHeaders.AUTHORIZATION, s"Bearer
$gptApiKey")).asJava)
+ .build()
+ }
- private val requestConfig = {
- val connectTimeout =
conf.get(KyuubiConf.ENGINE_CHAT_GPT_HTTP_CONNECT_TIMEOUT).asInstanceOf[Int]
- val socketTimeout =
conf.get(KyuubiConf.ENGINE_CHAT_GPT_HTTP_SOCKET_TIMEOUT).asInstanceOf[Int]
+ private val requestConfig: RequestConfig = {
+ val connectTimeout =
conf.get(KyuubiConf.ENGINE_CHAT_GPT_HTTP_CONNECT_TIMEOUT).intValue()
+ val socketTimeout =
conf.get(KyuubiConf.ENGINE_CHAT_GPT_HTTP_SOCKET_TIMEOUT).intValue()
val builder: RequestConfig.Builder = RequestConfig.custom()
.setConnectTimeout(connectTimeout)
.setSocketTimeout(socketTimeout)
@@ -70,8 +78,6 @@ class ChatGPTProvider(conf: KyuubiConf) extends ChatProvider {
messages.addLast(Message("user", q))
val request = new HttpPost("https://api.openai.com/v1/chat/completions")
- request.addHeader("Authorization", "Bearer " + gptApiKey)
-
val req = Map(
"messages" -> messages,
"model" -> "gpt-3.5-turbo",
@@ -81,17 +87,17 @@ class ChatGPTProvider(conf: KyuubiConf) extends
ChatProvider {
val entity = new StringEntity(mapper.writeValueAsString(req),
ContentType.APPLICATION_JSON)
request.setEntity(entity)
request.setConfig(requestConfig)
- val responseEntity = httpClient.execute(request)
- val respJson =
mapper.readTree(EntityUtils.toString(responseEntity.getEntity))
- val statusCode = responseEntity.getStatusLine.getStatusCode
- if (responseEntity.getStatusLine.getStatusCode == HttpStatus.SC_OK) {
- val replyMessage = mapper.treeToValue[Message](
- respJson.get("choices").get(0).get("message"))
- messages.addLast(replyMessage)
- replyMessage.content
- } else {
- messages.removeLast()
- s"Chat failed. Status: $statusCode.
${respJson.get("error").get("message").asText}"
+ val response = httpClient.execute(request)
+ val respJson = mapper.readTree(EntityUtils.toString(response.getEntity))
+ response.getStatusLine.getStatusCode match {
+ case HttpStatus.SC_OK =>
+ val replyMessage = mapper.treeToValue[Message](
+ respJson.get("choices").get(0).get("message"))
+ messages.addLast(replyMessage)
+ replyMessage.content
+ case errorStatusCode =>
+ messages.removeLast()
+ s"Chat failed. Status: $errorStatusCode.
${respJson.get("error").get("message").asText}"
}
}