This is an automated email from the ASF dual-hosted git repository.

chengpan pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/kyuubi.git


The following commit(s) were added to refs/heads/master by this push:
     new 067c6010a [KYUUBI #4554] [CHAT] Code improvement in ChatGPTProvider
067c6010a is described below

commit 067c6010a668c1ed9e83a991e589fb886304c959
Author: liangbowen <[email protected]>
AuthorDate: Sat Mar 18 15:34:16 2023 +0800

    [KYUUBI #4554] [CHAT] Code improvement in ChatGPTProvider
    
    ### _Why are the changes needed?_
    
    - set the Authorization header as a default header when constructing the HTTP client, instead of adding it to each request
    - handle the response status code with a Scala `match` expression instead of `if`/`else`
    - convert the config's long timeout values to int with `.intValue` instead of an `asInstanceOf` cast
    - rename the variable `responseEntity` to `response`
    
    ### _How was this patch tested?_
    - [ ] Add some test cases that check the changes thoroughly, including negative and positive cases if possible
    
    - [ ] Add screenshots for manual tests if appropriate
    
    - [ ] [Run test](https://kyuubi.readthedocs.io/en/master/develop_tools/testing.html#running-tests) locally before making a pull request
    
    Closes #4554 from bowenliang123/chatgpt-http.
    
    Closes #4554
    
    114484a4d [liangbowen] httpclient improvement in ChatGPTProvider
    
    Authored-by: liangbowen <[email protected]>
    Signed-off-by: Cheng Pan <[email protected]>
---
 .../engine/chat/provider/ChatGPTProvider.scala     | 42 ++++++++++++----------
 1 file changed, 24 insertions(+), 18 deletions(-)

diff --git a/externals/kyuubi-chat-engine/src/main/scala/org/apache/kyuubi/engine/chat/provider/ChatGPTProvider.scala b/externals/kyuubi-chat-engine/src/main/scala/org/apache/kyuubi/engine/chat/provider/ChatGPTProvider.scala
index 28ef36bb4..a4cdb7c94 100644
--- a/externals/kyuubi-chat-engine/src/main/scala/org/apache/kyuubi/engine/chat/provider/ChatGPTProvider.scala
+++ b/externals/kyuubi-chat-engine/src/main/scala/org/apache/kyuubi/engine/chat/provider/ChatGPTProvider.scala
@@ -20,12 +20,15 @@ package org.apache.kyuubi.engine.chat.provider
 import java.util
 import java.util.concurrent.TimeUnit
 
+import scala.collection.JavaConverters._
+
 import com.google.common.cache.{CacheBuilder, CacheLoader, LoadingCache}
-import org.apache.http.{HttpHost, HttpStatus}
+import org.apache.http.{HttpHeaders, HttpHost, HttpStatus}
 import org.apache.http.client.config.RequestConfig
 import org.apache.http.client.methods.HttpPost
 import org.apache.http.entity.{ContentType, StringEntity}
 import org.apache.http.impl.client.{CloseableHttpClient, HttpClientBuilder}
+import org.apache.http.message.BasicHeader
 import org.apache.http.util.EntityUtils
 
 import org.apache.kyuubi.config.KyuubiConf
@@ -39,11 +42,16 @@ class ChatGPTProvider(conf: KyuubiConf) extends ChatProvider {
         s"which could be got at https://platform.openai.com/account/api-keys";)
   }
 
-  private val httpClient: CloseableHttpClient = 
HttpClientBuilder.create().build()
+  private val httpClient: CloseableHttpClient = {
+    HttpClientBuilder.create()
+      .setDefaultHeaders(List(
+        new BasicHeader(HttpHeaders.AUTHORIZATION, s"Bearer 
$gptApiKey")).asJava)
+      .build()
+  }
 
-  private val requestConfig = {
-    val connectTimeout = 
conf.get(KyuubiConf.ENGINE_CHAT_GPT_HTTP_CONNECT_TIMEOUT).asInstanceOf[Int]
-    val socketTimeout = 
conf.get(KyuubiConf.ENGINE_CHAT_GPT_HTTP_SOCKET_TIMEOUT).asInstanceOf[Int]
+  private val requestConfig: RequestConfig = {
+    val connectTimeout = 
conf.get(KyuubiConf.ENGINE_CHAT_GPT_HTTP_CONNECT_TIMEOUT).intValue()
+    val socketTimeout = 
conf.get(KyuubiConf.ENGINE_CHAT_GPT_HTTP_SOCKET_TIMEOUT).intValue()
     val builder: RequestConfig.Builder = RequestConfig.custom()
       .setConnectTimeout(connectTimeout)
       .setSocketTimeout(socketTimeout)
@@ -70,8 +78,6 @@ class ChatGPTProvider(conf: KyuubiConf) extends ChatProvider {
     messages.addLast(Message("user", q))
 
     val request = new HttpPost("https://api.openai.com/v1/chat/completions";)
-    request.addHeader("Authorization", "Bearer " + gptApiKey)
-
     val req = Map(
       "messages" -> messages,
       "model" -> "gpt-3.5-turbo",
@@ -81,17 +87,17 @@ class ChatGPTProvider(conf: KyuubiConf) extends ChatProvider {
     val entity = new StringEntity(mapper.writeValueAsString(req), 
ContentType.APPLICATION_JSON)
     request.setEntity(entity)
     request.setConfig(requestConfig)
-    val responseEntity = httpClient.execute(request)
-    val respJson = 
mapper.readTree(EntityUtils.toString(responseEntity.getEntity))
-    val statusCode = responseEntity.getStatusLine.getStatusCode
-    if (responseEntity.getStatusLine.getStatusCode == HttpStatus.SC_OK) {
-      val replyMessage = mapper.treeToValue[Message](
-        respJson.get("choices").get(0).get("message"))
-      messages.addLast(replyMessage)
-      replyMessage.content
-    } else {
-      messages.removeLast()
-      s"Chat failed. Status: $statusCode. 
${respJson.get("error").get("message").asText}"
+    val response = httpClient.execute(request)
+    val respJson = mapper.readTree(EntityUtils.toString(response.getEntity))
+    response.getStatusLine.getStatusCode match {
+      case HttpStatus.SC_OK =>
+        val replyMessage = mapper.treeToValue[Message](
+          respJson.get("choices").get(0).get("message"))
+        messages.addLast(replyMessage)
+        replyMessage.content
+      case errorStatusCode =>
+        messages.removeLast()
+        s"Chat failed. Status: $errorStatusCode. 
${respJson.get("error").get("message").asText}"
     }
   }
 

Reply via email to