This is an automated email from the ASF dual-hosted git repository.
bowenliang pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/kyuubi.git
The following commit(s) were added to refs/heads/master by this push:
new 24e87ef21 [KYUUBI #4556] [CHAT] Refactor ChatGPTProvider to use
`openai-java` client
24e87ef21 is described below
commit 24e87ef21a4aeb3eaf2a951533b591b56d5f9a52
Author: liangbowen <[email protected]>
AuthorDate: Sun Mar 19 19:24:41 2023 +0800
[KYUUBI #4556] [CHAT] Refactor ChatGPTProvider to use `openai-java` client
### _Why are the changes needed?_
- Use the Java SDK `openai-java` for ChatGPT, which is popular and listed on
the official website: https://github.com/TheoKanning/openai-java
- Focus ChatGPTProvider on lifecycle management, and avoid handling lower-level
details in it, such as POJO mapping and HTTP request handling.
- Follow the upstream changes from OpenAI
### _How was this patch tested?_
- [ ] Add some test cases that check the changes thoroughly including
negative and positive cases if possible
- [x] Add screenshots for manual tests if appropriate
- [x] [Run
test](https://kyuubi.readthedocs.io/en/master/develop_tools/testing.html#running-tests)
locally before making a pull request
Closes #4556 from bowenliang123/chatgpt-third.
Closes #4556
ecf1e2cf6 [liangbowen] manually add `openai-gpt3-java:*` and its dependency
to LICENSE-binary
53b8375a5 [liangbowen] refactor ChatGPTProvider to use `openai-java` SDK
Authored-by: liangbowen <[email protected]>
Signed-off-by: liangbowen <[email protected]>
---
LICENSE-binary | 5 ++
externals/kyuubi-chat-engine/pom.xml | 5 +-
.../engine/chat/provider/ChatGPTProvider.scala | 87 ++++++++++------------
pom.xml | 7 ++
4 files changed, 53 insertions(+), 51 deletions(-)
diff --git a/LICENSE-binary b/LICENSE-binary
index 92daf62ab..feab9965e 100644
--- a/LICENSE-binary
+++ b/LICENSE-binary
@@ -319,6 +319,8 @@ io.swagger.core.v3:swagger-models
io.vertx:vertx-core
io.vertx:vertx-grpc
org.apache.zookeeper:zookeeper
+com.squareup.retrofit2:retrofit
+com.squareup.okhttp3:okhttp
BSD
------------
@@ -356,6 +358,9 @@ org.codehaus.mojo:animal-sniffer-annotations
org.slf4j:slf4j-api
org.slf4j:jcl-over-slf4j
org.slf4j:jul-over-slf4j
+com.theokanning.openai-gpt3-java:api
+com.theokanning.openai-gpt3-java:client
+com.theokanning.openai-gpt3-java:service
kyuubi-server/src/main/resources/org/apache/kyuubi/ui/static/assets/fonts/*
kyuubi-server/src/main/resources/org/apache/kyuubi/ui/static/icon.min.css
diff --git a/externals/kyuubi-chat-engine/pom.xml
b/externals/kyuubi-chat-engine/pom.xml
index 7e2178918..28779f450 100644
--- a/externals/kyuubi-chat-engine/pom.xml
+++ b/externals/kyuubi-chat-engine/pom.xml
@@ -45,8 +45,9 @@
</dependency>
<dependency>
- <groupId>org.apache.httpcomponents</groupId>
- <artifactId>httpclient</artifactId>
+ <groupId>com.theokanning.openai-gpt3-java</groupId>
+ <artifactId>service</artifactId>
+ <version>${openai.java.version}</version>
</dependency>
<dependency>
diff --git
a/externals/kyuubi-chat-engine/src/main/scala/org/apache/kyuubi/engine/chat/provider/ChatGPTProvider.scala
b/externals/kyuubi-chat-engine/src/main/scala/org/apache/kyuubi/engine/chat/provider/ChatGPTProvider.scala
index a4cdb7c94..2e4bf3f8d 100644
---
a/externals/kyuubi-chat-engine/src/main/scala/org/apache/kyuubi/engine/chat/provider/ChatGPTProvider.scala
+++
b/externals/kyuubi-chat-engine/src/main/scala/org/apache/kyuubi/engine/chat/provider/ChatGPTProvider.scala
@@ -17,22 +17,20 @@
package org.apache.kyuubi.engine.chat.provider
+import java.net.{InetSocketAddress, Proxy, URL}
+import java.time.Duration
import java.util
import java.util.concurrent.TimeUnit
import scala.collection.JavaConverters._
import com.google.common.cache.{CacheBuilder, CacheLoader, LoadingCache}
-import org.apache.http.{HttpHeaders, HttpHost, HttpStatus}
-import org.apache.http.client.config.RequestConfig
-import org.apache.http.client.methods.HttpPost
-import org.apache.http.entity.{ContentType, StringEntity}
-import org.apache.http.impl.client.{CloseableHttpClient, HttpClientBuilder}
-import org.apache.http.message.BasicHeader
-import org.apache.http.util.EntityUtils
+import com.theokanning.openai.OpenAiApi
+import com.theokanning.openai.completion.chat.{ChatCompletionRequest,
ChatMessage}
+import com.theokanning.openai.service.OpenAiService
+import com.theokanning.openai.service.OpenAiService.{defaultClient,
defaultObjectMapper, defaultRetrofit}
import org.apache.kyuubi.config.KyuubiConf
-import org.apache.kyuubi.engine.chat.provider.ChatProvider.mapper
class ChatGPTProvider(conf: KyuubiConf) extends ChatProvider {
@@ -42,31 +40,32 @@ class ChatGPTProvider(conf: KyuubiConf) extends
ChatProvider {
s"which could be got at https://platform.openai.com/account/api-keys")
}
- private val httpClient: CloseableHttpClient = {
- HttpClientBuilder.create()
- .setDefaultHeaders(List(
- new BasicHeader(HttpHeaders.AUTHORIZATION, s"Bearer
$gptApiKey")).asJava)
- .build()
- }
+ private val openAiService: OpenAiService = {
+ val builder = defaultClient(
+ gptApiKey,
+
Duration.ofMillis(conf.get(KyuubiConf.ENGINE_CHAT_GPT_HTTP_SOCKET_TIMEOUT)))
+ .newBuilder
+
.connectTimeout(Duration.ofMillis(conf.get(KyuubiConf.ENGINE_CHAT_GPT_HTTP_CONNECT_TIMEOUT)))
- private val requestConfig: RequestConfig = {
- val connectTimeout =
conf.get(KyuubiConf.ENGINE_CHAT_GPT_HTTP_CONNECT_TIMEOUT).intValue()
- val socketTimeout =
conf.get(KyuubiConf.ENGINE_CHAT_GPT_HTTP_SOCKET_TIMEOUT).intValue()
- val builder: RequestConfig.Builder = RequestConfig.custom()
- .setConnectTimeout(connectTimeout)
- .setSocketTimeout(socketTimeout)
- conf.get(KyuubiConf.ENGINE_CHAT_GPT_HTTP_PROXY).foreach { url =>
- builder.setProxy(HttpHost.create(url))
+ conf.get(KyuubiConf.ENGINE_CHAT_GPT_HTTP_PROXY) match {
+ case Some(httpProxyUrl) =>
+ val url = new URL(httpProxyUrl)
+ val proxy = new Proxy(Proxy.Type.HTTP, new
InetSocketAddress(url.getHost, url.getPort))
+ builder.proxy(proxy)
+ case _ =>
}
- builder.build()
+
+ val retrofit = defaultRetrofit(builder.build(), defaultObjectMapper)
+ val api = retrofit.create(classOf[OpenAiApi])
+ new OpenAiService(api)
}
- private val chatHistory: LoadingCache[String, util.ArrayDeque[Message]] =
+ private val chatHistory: LoadingCache[String, util.ArrayDeque[ChatMessage]] =
CacheBuilder.newBuilder()
.expireAfterWrite(10, TimeUnit.MINUTES)
- .build(new CacheLoader[String, util.ArrayDeque[Message]] {
- override def load(sessionId: String): util.ArrayDeque[Message] =
- new util.ArrayDeque[Message]
+ .build(new CacheLoader[String, util.ArrayDeque[ChatMessage]] {
+ override def load(sessionId: String): util.ArrayDeque[ChatMessage] =
+ new util.ArrayDeque[ChatMessage]
})
override def open(sessionId: String): Unit = {
@@ -75,29 +74,19 @@ class ChatGPTProvider(conf: KyuubiConf) extends
ChatProvider {
override def ask(sessionId: String, q: String): String = {
val messages = chatHistory.get(sessionId)
- messages.addLast(Message("user", q))
-
- val request = new HttpPost("https://api.openai.com/v1/chat/completions")
- val req = Map(
- "messages" -> messages,
- "model" -> "gpt-3.5-turbo",
- "max_tokens" -> 200,
- "temperature" -> 0.5,
- "top_p" -> 1)
- val entity = new StringEntity(mapper.writeValueAsString(req),
ContentType.APPLICATION_JSON)
- request.setEntity(entity)
- request.setConfig(requestConfig)
- val response = httpClient.execute(request)
- val respJson = mapper.readTree(EntityUtils.toString(response.getEntity))
- response.getStatusLine.getStatusCode match {
- case HttpStatus.SC_OK =>
- val replyMessage = mapper.treeToValue[Message](
- respJson.get("choices").get(0).get("message"))
- messages.addLast(replyMessage)
- replyMessage.content
- case errorStatusCode =>
+ try {
+ messages.addLast(new ChatMessage("user", q))
+ val completionRequest = ChatCompletionRequest.builder()
+ .messages(messages.asScala.toList.asJava)
+ .model("gpt-3.5-turbo")
+ .build()
+ val responseText =
openAiService.createChatCompletion(completionRequest).getChoices.asScala
+ .map(c => c.getMessage.getContent).mkString
+ responseText
+ } catch {
+ case e: Throwable =>
messages.removeLast()
- s"Chat failed. Status: $errorStatusCode.
${respJson.get("error").get("message").asText}"
+ s"Chat failed. Error: ${e.getMessage}"
}
}
diff --git a/pom.xml b/pom.xml
index a2de5d29d..7f6d1ed71 100644
--- a/pom.xml
+++ b/pom.xml
@@ -174,6 +174,7 @@
<log4j.version>2.20.0</log4j.version>
<mysql.jdbc.version>8.0.32</mysql.jdbc.version>
<netty.version>4.1.89.Final</netty.version>
+ <openai.java.version>0.11.1</openai.java.version>
<parquet.version>1.10.1</parquet.version>
<phoenix.version>6.0.0</phoenix.version>
<prometheus.version>0.16.0</prometheus.version>
@@ -1658,6 +1659,12 @@
<artifactId>py4j</artifactId>
<version>${py4j.version}</version>
</dependency>
+
+ <dependency>
+ <groupId>com.theokanning.openai-gpt3-java</groupId>
+ <artifactId>service</artifactId>
+ <version>${openai.java.version}</version>
+ </dependency>
</dependencies>
</dependencyManagement>