This is an automated email from the ASF dual-hosted git repository.

menghaoran pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/shardingsphere.git


The following commit(s) were added to refs/heads/master by this push:
     new dee989c  Merge error packet processors into PostgreSQLErrPacketFactory (#10097)
dee989c is described below

commit dee989c66b714bcf3381eec134be62467f751017
Author: Liang Zhang <[email protected]>
AuthorDate: Thu Apr 15 16:45:34 2021 +0800

    Merge error packet processors into PostgreSQLErrPacketFactory (#10097)
    
    * Refactor PostgreSQLAuthenticationEngine
    
    * Refactor PostgreSQLErrPacketFactory
    
    * Refactor PostgreSQLCommandExecuteEngine
    
    * Refactor PostgreSQLAuthenticationEngine
---
 .../generic/PostgreSQLErrorResponsePacket.java     | 32 ++++++++++++++--------
 .../generic/PostgreSQLErrorResponsePacketTest.java | 25 +++++++++++++----
 .../PostgreSQLAuthenticationEngine.java            | 24 +++++++++-------
 .../command/PostgreSQLCommandExecuteEngine.java    |  6 ++--
 .../postgresql/err/PostgreSQLErrPacketFactory.java | 32 ++++++++++++----------
 5 files changed, 74 insertions(+), 45 deletions(-)

diff --git 
a/shardingsphere-db-protocol/shardingsphere-db-protocol-postgresql/src/main/java/org/apache/shardingsphere/db/protocol/postgresql/packet/generic/PostgreSQLErrorResponsePacket.java
 
b/shardingsphere-db-protocol/shardingsphere-db-protocol-postgresql/src/main/java/org/apache/shardingsphere/db/protocol/postgresql/packet/generic/PostgreSQLErrorResponsePacket.java
index bf8f8ad..3f11bdf 100644
--- 
a/shardingsphere-db-protocol/shardingsphere-db-protocol-postgresql/src/main/java/org/apache/shardingsphere/db/protocol/postgresql/packet/generic/PostgreSQLErrorResponsePacket.java
+++ 
b/shardingsphere-db-protocol/shardingsphere-db-protocol-postgresql/src/main/java/org/apache/shardingsphere/db/protocol/postgresql/packet/generic/PostgreSQLErrorResponsePacket.java
@@ -22,9 +22,10 @@ import 
org.apache.shardingsphere.db.protocol.postgresql.packet.identifier.Postgr
 import 
org.apache.shardingsphere.db.protocol.postgresql.packet.identifier.PostgreSQLMessagePacketType;
 import 
org.apache.shardingsphere.db.protocol.postgresql.payload.PostgreSQLPacketPayload;
 
-import java.util.HashMap;
+import java.util.LinkedHashMap;
 import java.util.Map;
 import java.util.Map.Entry;
+import java.util.stream.Collectors;
 
 /**
  * Error response packet for PostgreSQL.
@@ -69,16 +70,7 @@ public final class PostgreSQLErrorResponsePacket implements 
PostgreSQLIdentifier
     
     public static final char FIELD_TYPE_ROUTINE = 'R';
     
-    private final Map<Character, String> fields = new HashMap<>();
-    
-    @Override
-    public void write(final PostgreSQLPacketPayload payload) {
-        for (Entry<Character, String> each : fields.entrySet()) {
-            payload.writeInt1(each.getKey());
-            payload.writeStringNul(each.getValue());
-        }
-        payload.writeInt1(0);
-    }
+    private final Map<Character, String> fields = new LinkedHashMap<>();
     
     /**
      * Add field.
@@ -90,6 +82,24 @@ public final class PostgreSQLErrorResponsePacket implements 
PostgreSQLIdentifier
         fields.put(fieldType, fieldValue);
     }
     
+    /**
+     * To server error message.
+     * 
+     * @return server error message
+     */
+    public String toServerErrorMessage() {
+        return fields.entrySet().stream().map(entry -> entry.getKey() + 
entry.getValue()).collect(Collectors.joining("\0"));
+    }
+    
+    @Override
+    public void write(final PostgreSQLPacketPayload payload) {
+        for (Entry<Character, String> each : fields.entrySet()) {
+            payload.writeInt1(each.getKey());
+            payload.writeStringNul(each.getValue());
+        }
+        payload.writeInt1(0);
+    }
+    
     @Override
     public PostgreSQLIdentifierTag getIdentifier() {
         return PostgreSQLMessagePacketType.ERROR_RESPONSE;
diff --git 
a/shardingsphere-db-protocol/shardingsphere-db-protocol-postgresql/src/test/java/org/apache/shardingsphere/db/protocol/postgresql/packet/generic/PostgreSQLErrorResponsePacketTest.java
 
b/shardingsphere-db-protocol/shardingsphere-db-protocol-postgresql/src/test/java/org/apache/shardingsphere/db/protocol/postgresql/packet/generic/PostgreSQLErrorResponsePacketTest.java
index 090e375..fd404d8 100644
--- 
a/shardingsphere-db-protocol/shardingsphere-db-protocol-postgresql/src/test/java/org/apache/shardingsphere/db/protocol/postgresql/packet/generic/PostgreSQLErrorResponsePacketTest.java
+++ 
b/shardingsphere-db-protocol/shardingsphere-db-protocol-postgresql/src/test/java/org/apache/shardingsphere/db/protocol/postgresql/packet/generic/PostgreSQLErrorResponsePacketTest.java
@@ -23,20 +23,25 @@ import org.junit.runner.RunWith;
 import org.mockito.Mock;
 import org.mockito.junit.MockitoJUnitRunner;
 
+import static org.hamcrest.CoreMatchers.is;
+import static org.junit.Assert.assertThat;
 import static org.mockito.Mockito.verify;
 
 @RunWith(MockitoJUnitRunner.class)
 public final class PostgreSQLErrorResponsePacketTest {
-
+    
     @Mock
     private PostgreSQLPacketPayload payload;
-
+    
+    @Test
+    public void assertToServerErrorMessage() {
+        PostgreSQLErrorResponsePacket responsePacket = 
createErrorResponsePacket();
+        assertThat(responsePacket.toServerErrorMessage(), 
is("SFATAL\0C3D000\0Mdatabase \"test\" does not exist"));
+    }
+    
     @Test
     public void assertWrite() {
-        PostgreSQLErrorResponsePacket responsePacket = new 
PostgreSQLErrorResponsePacket();
-        
responsePacket.addField(PostgreSQLErrorResponsePacket.FIELD_TYPE_SEVERITY, 
"FATAL");
-        responsePacket.addField(PostgreSQLErrorResponsePacket.FIELD_TYPE_CODE, 
"3D000");
-        
responsePacket.addField(PostgreSQLErrorResponsePacket.FIELD_TYPE_MESSAGE, 
"database \"test\" does not exist");
+        PostgreSQLErrorResponsePacket responsePacket = 
createErrorResponsePacket();
         responsePacket.write(payload);
         
verify(payload).writeInt1(PostgreSQLErrorResponsePacket.FIELD_TYPE_SEVERITY);
         verify(payload).writeStringNul("FATAL");
@@ -46,4 +51,12 @@ public final class PostgreSQLErrorResponsePacketTest {
         verify(payload).writeStringNul("database \"test\" does not exist");
         verify(payload).writeInt1(0);
     }
+    
+    private PostgreSQLErrorResponsePacket createErrorResponsePacket() {
+        PostgreSQLErrorResponsePacket result = new 
PostgreSQLErrorResponsePacket();
+        result.addField(PostgreSQLErrorResponsePacket.FIELD_TYPE_SEVERITY, 
"FATAL");
+        result.addField(PostgreSQLErrorResponsePacket.FIELD_TYPE_CODE, 
"3D000");
+        result.addField(PostgreSQLErrorResponsePacket.FIELD_TYPE_MESSAGE, 
"database \"test\" does not exist");
+        return result;
+    }
 }
diff --git 
a/shardingsphere-proxy/shardingsphere-proxy-frontend/shardingsphere-proxy-frontend-postgresql/src/main/java/org/apache/shardingsphere/proxy/frontend/postgresql/authentication/PostgreSQLAuthenticationEngine.java
 
b/shardingsphere-proxy/shardingsphere-proxy-frontend/shardingsphere-proxy-frontend-postgresql/src/main/java/org/apache/shardingsphere/proxy/frontend/postgresql/authentication/PostgreSQLAuthenticationEngine.java
index fd88d0b..e207b2e 100644
--- 
a/shardingsphere-proxy/shardingsphere-proxy-frontend/shardingsphere-proxy-frontend-postgresql/src/main/java/org/apache/shardingsphere/proxy/frontend/postgresql/authentication/PostgreSQLAuthenticationEngine.java
+++ 
b/shardingsphere-proxy/shardingsphere-proxy-frontend/shardingsphere-proxy-frontend-postgresql/src/main/java/org/apache/shardingsphere/proxy/frontend/postgresql/authentication/PostgreSQLAuthenticationEngine.java
@@ -31,12 +31,16 @@ import 
org.apache.shardingsphere.db.protocol.postgresql.packet.handshake.Postgre
 import 
org.apache.shardingsphere.db.protocol.postgresql.packet.handshake.PostgreSQLPasswordMessagePacket;
 import 
org.apache.shardingsphere.db.protocol.postgresql.packet.handshake.PostgreSQLRandomGenerator;
 import 
org.apache.shardingsphere.db.protocol.postgresql.packet.handshake.PostgreSQLSSLNegativePacket;
+import 
org.apache.shardingsphere.db.protocol.postgresql.packet.identifier.PostgreSQLMessagePacketType;
 import 
org.apache.shardingsphere.db.protocol.postgresql.payload.PostgreSQLPacketPayload;
 import org.apache.shardingsphere.proxy.backend.context.ProxyContext;
 import 
org.apache.shardingsphere.proxy.frontend.authentication.AuthenticationEngine;
 import 
org.apache.shardingsphere.proxy.frontend.authentication.AuthenticationResult;
 import 
org.apache.shardingsphere.proxy.frontend.authentication.AuthenticationResultBuilder;
 import 
org.apache.shardingsphere.proxy.frontend.connection.ConnectionIdGenerator;
+import 
org.apache.shardingsphere.proxy.frontend.postgresql.err.PostgreSQLErrPacketFactory;
+import org.postgresql.util.PSQLException;
+import org.postgresql.util.ServerErrorMessage;
 
 /**
  * Authentication engine for PostgreSQL.
@@ -67,12 +71,12 @@ public final class PostgreSQLAuthenticationEngine 
implements AuthenticationEngin
             return AuthenticationResultBuilder.continued();
         }
         payload.getByteBuf().resetReaderIndex();
-        return startupMessageReceived ? afterStartupMessage(context, 
(PostgreSQLPacketPayload) payload) : beforeStartupMessage(context, 
(PostgreSQLPacketPayload) payload);
+        return startupMessageReceived ? processPasswordMessage(context, 
(PostgreSQLPacketPayload) payload) : processStartupMessage(context, 
(PostgreSQLPacketPayload) payload);
     }
     
-    private AuthenticationResult beforeStartupMessage(final 
ChannelHandlerContext context, final PostgreSQLPacketPayload payload) {
-        PostgreSQLComStartupPacket comStartupPacket = new 
PostgreSQLComStartupPacket(payload);
+    private AuthenticationResult processStartupMessage(final 
ChannelHandlerContext context, final PostgreSQLPacketPayload payload) {
         startupMessageReceived = true;
+        PostgreSQLComStartupPacket comStartupPacket = new 
PostgreSQLComStartupPacket(payload);
         String database = comStartupPacket.getDatabase();
         if (!Strings.isNullOrEmpty(database) && 
!ProxyContext.getInstance().schemaExists(database)) {
             
context.writeAndFlush(createErrorPacket(PostgreSQLErrorCode.INVALID_CATALOG_NAME,
 String.format("database \"%s\" does not exist", database)));
@@ -91,9 +95,9 @@ public final class PostgreSQLAuthenticationEngine implements 
AuthenticationEngin
         return currentAuthResult;
     }
     
-    private AuthenticationResult afterStartupMessage(final 
ChannelHandlerContext context, final PostgreSQLPacketPayload payload) {
+    private AuthenticationResult processPasswordMessage(final 
ChannelHandlerContext context, final PostgreSQLPacketPayload payload) {
         char messageType = (char) payload.readInt1();
-        if ('p' != messageType) {
+        if (PostgreSQLMessagePacketType.PASSWORD_MESSAGE.getValue() != 
messageType) {
             PostgreSQLErrorResponsePacket responsePacket = createErrorPacket(
                     
PostgreSQLErrorCode.SQLSERVER_REJECTED_ESTABLISHMENT_OF_SQLCONNECTION, 
String.format("PasswordMessage is expected, message type 'p', but not '%s'", 
messageType));
             context.writeAndFlush(responsePacket);
@@ -120,10 +124,10 @@ public final class PostgreSQLAuthenticationEngine 
implements AuthenticationEngin
     }
     
     private PostgreSQLErrorResponsePacket createErrorPacket(final 
PostgreSQLErrorCode errorCode, final String errorMessage) {
-        PostgreSQLErrorResponsePacket result = new 
PostgreSQLErrorResponsePacket();
-        result.addField(PostgreSQLErrorResponsePacket.FIELD_TYPE_SEVERITY, 
"FATAL");
-        result.addField(PostgreSQLErrorResponsePacket.FIELD_TYPE_CODE, 
errorCode.getErrorCode());
-        result.addField(PostgreSQLErrorResponsePacket.FIELD_TYPE_MESSAGE, 
Strings.isNullOrEmpty(errorMessage) ? errorCode.getConditionName() : 
errorMessage);
-        return result;
+        PostgreSQLErrorResponsePacket packet = new 
PostgreSQLErrorResponsePacket();
+        packet.addField(PostgreSQLErrorResponsePacket.FIELD_TYPE_SEVERITY, 
"FATAL");
+        packet.addField(PostgreSQLErrorResponsePacket.FIELD_TYPE_CODE, 
errorCode.getErrorCode());
+        packet.addField(PostgreSQLErrorResponsePacket.FIELD_TYPE_MESSAGE, 
Strings.isNullOrEmpty(errorMessage) ? errorCode.getConditionName() : 
errorMessage);
+        return PostgreSQLErrPacketFactory.newInstance(new PSQLException(new 
ServerErrorMessage(packet.toServerErrorMessage())));
     }
 }
diff --git 
a/shardingsphere-proxy/shardingsphere-proxy-frontend/shardingsphere-proxy-frontend-postgresql/src/main/java/org/apache/shardingsphere/proxy/frontend/postgresql/command/PostgreSQLCommandExecuteEngine.java
 
b/shardingsphere-proxy/shardingsphere-proxy-frontend/shardingsphere-proxy-frontend-postgresql/src/main/java/org/apache/shardingsphere/proxy/frontend/postgresql/command/PostgreSQLCommandExecuteEngine.java
index 3e3da17..a6e5971 100644
--- 
a/shardingsphere-proxy/shardingsphere-proxy-frontend/shardingsphere-proxy-frontend-postgresql/src/main/java/org/apache/shardingsphere/proxy/frontend/postgresql/command/PostgreSQLCommandExecuteEngine.java
+++ 
b/shardingsphere-proxy/shardingsphere-proxy-frontend/shardingsphere-proxy-frontend-postgresql/src/main/java/org/apache/shardingsphere/proxy/frontend/postgresql/command/PostgreSQLCommandExecuteEngine.java
@@ -27,7 +27,6 @@ import 
org.apache.shardingsphere.db.protocol.postgresql.packet.command.PostgreSQ
 import 
org.apache.shardingsphere.db.protocol.postgresql.packet.command.PostgreSQLCommandPacketType;
 import 
org.apache.shardingsphere.db.protocol.postgresql.packet.command.PostgreSQLCommandPacketTypeLoader;
 import 
org.apache.shardingsphere.db.protocol.postgresql.packet.generic.PostgreSQLCommandCompletePacket;
-import 
org.apache.shardingsphere.db.protocol.postgresql.packet.generic.PostgreSQLErrorResponsePacket;
 import 
org.apache.shardingsphere.db.protocol.postgresql.packet.generic.PostgreSQLReadyForQueryPacket;
 import 
org.apache.shardingsphere.db.protocol.postgresql.payload.PostgreSQLPacketPayload;
 import 
org.apache.shardingsphere.infra.config.properties.ConfigurationPropertyKey;
@@ -37,6 +36,7 @@ import 
org.apache.shardingsphere.proxy.frontend.command.CommandExecuteEngine;
 import 
org.apache.shardingsphere.proxy.frontend.command.executor.CommandExecutor;
 import 
org.apache.shardingsphere.proxy.frontend.command.executor.QueryCommandExecutor;
 import org.apache.shardingsphere.proxy.frontend.command.executor.ResponseType;
+import 
org.apache.shardingsphere.proxy.frontend.postgresql.err.PostgreSQLErrPacketFactory;
 
 import java.sql.SQLException;
 import java.util.Optional;
@@ -63,9 +63,7 @@ public final class PostgreSQLCommandExecuteEngine implements 
CommandExecuteEngin
     
     @Override
     public DatabasePacket<?> getErrorPacket(final Exception cause) {
-        PostgreSQLErrorResponsePacket result = new 
PostgreSQLErrorResponsePacket();
-        result.addField(PostgreSQLErrorResponsePacket.FIELD_TYPE_MESSAGE, 
cause.getMessage());
-        return result;
+        return PostgreSQLErrPacketFactory.newInstance(cause);
     }
     
     @Override
diff --git 
a/shardingsphere-proxy/shardingsphere-proxy-frontend/shardingsphere-proxy-frontend-postgresql/src/main/java/org/apache/shardingsphere/proxy/frontend/postgresql/err/PostgreSQLErrPacketFactory.java
 
b/shardingsphere-proxy/shardingsphere-proxy-frontend/shardingsphere-proxy-frontend-postgresql/src/main/java/org/apache/shardingsphere/proxy/frontend/postgresql/err/PostgreSQLErrPacketFactory.java
index d3a8faa..c0b807c 100644
--- 
a/shardingsphere-proxy/shardingsphere-proxy-frontend/shardingsphere-proxy-frontend-postgresql/src/main/java/org/apache/shardingsphere/proxy/frontend/postgresql/err/PostgreSQLErrPacketFactory.java
+++ 
b/shardingsphere-proxy/shardingsphere-proxy-frontend/shardingsphere-proxy-frontend-postgresql/src/main/java/org/apache/shardingsphere/proxy/frontend/postgresql/err/PostgreSQLErrPacketFactory.java
@@ -23,6 +23,8 @@ import 
org.apache.shardingsphere.db.protocol.postgresql.packet.generic.PostgreSQ
 import org.postgresql.util.PSQLException;
 import org.postgresql.util.ServerErrorMessage;
 
+import java.sql.SQLException;
+
 /**
  * ERR packet factory for PostgreSQL.
  */
@@ -36,33 +38,35 @@ public final class PostgreSQLErrPacketFactory {
      * @return instance of PostgreSQL ERR packet
      */
     public static PostgreSQLErrorResponsePacket newInstance(final Exception 
cause) {
-        return cause instanceof PSQLException ? 
createErrorResponsePacket((PSQLException) cause) : 
createErrorResponsePacket(cause);
+        if (cause instanceof PSQLException && null != ((PSQLException) 
cause).getServerErrorMessage()) {
+            return createErrorResponsePacket(((PSQLException) 
cause).getServerErrorMessage());
+        }
+        if (cause instanceof SQLException) {
+            return createErrorResponsePacket((SQLException) cause);
+        }
+        return createErrorResponsePacket(cause);
     }
     
-    private static PostgreSQLErrorResponsePacket 
createErrorResponsePacket(final Exception cause) {
+    private static PostgreSQLErrorResponsePacket 
createErrorResponsePacket(final ServerErrorMessage serverErrorMessage) {
         PostgreSQLErrorResponsePacket result = new 
PostgreSQLErrorResponsePacket();
-        result.addField(PostgreSQLErrorResponsePacket.FIELD_TYPE_MESSAGE, 
cause.getMessage());
-        // TODO add common error code
+        result.addField(PostgreSQLErrorResponsePacket.FIELD_TYPE_CODE, 
serverErrorMessage.getSQLState());
+        result.addField(PostgreSQLErrorResponsePacket.FIELD_TYPE_MESSAGE, 
serverErrorMessage.getMessage());
+        result.addField(PostgreSQLErrorResponsePacket.FIELD_TYPE_SEVERITY, 
serverErrorMessage.getSeverity());
+        result.addField(PostgreSQLErrorResponsePacket.FIELD_TYPE_POSITION, 
Integer.toString(serverErrorMessage.getPosition()));
         return result;
     }
     
-    private static PostgreSQLErrorResponsePacket 
createErrorResponsePacket(final PSQLException cause) {
-        return null == cause.getServerErrorMessage() ? 
createErrorResponsePacketWithoutServerErrorMessage(cause) : 
createErrorResponsePacketByServerErrorMessage(cause.getServerErrorMessage());
-    }
-    
-    private static PostgreSQLErrorResponsePacket 
createErrorResponsePacketWithoutServerErrorMessage(final PSQLException cause) {
+    private static PostgreSQLErrorResponsePacket 
createErrorResponsePacket(final SQLException cause) {
         PostgreSQLErrorResponsePacket result = new 
PostgreSQLErrorResponsePacket();
         result.addField(PostgreSQLErrorResponsePacket.FIELD_TYPE_CODE, 
cause.getSQLState());
         result.addField(PostgreSQLErrorResponsePacket.FIELD_TYPE_MESSAGE, 
cause.getMessage());
         return result;
     }
     
-    private static PostgreSQLErrorResponsePacket 
createErrorResponsePacketByServerErrorMessage(final ServerErrorMessage 
serverErrorMessage) {
+    private static PostgreSQLErrorResponsePacket 
createErrorResponsePacket(final Exception cause) {
         PostgreSQLErrorResponsePacket result = new 
PostgreSQLErrorResponsePacket();
-        result.addField(PostgreSQLErrorResponsePacket.FIELD_TYPE_SEVERITY, 
serverErrorMessage.getSeverity());
-        result.addField(PostgreSQLErrorResponsePacket.FIELD_TYPE_CODE, 
serverErrorMessage.getSQLState());
-        result.addField(PostgreSQLErrorResponsePacket.FIELD_TYPE_MESSAGE, 
serverErrorMessage.getMessage());
-        result.addField(PostgreSQLErrorResponsePacket.FIELD_TYPE_POSITION, 
Integer.toString(serverErrorMessage.getPosition()));
+        // TODO add FIELD_TYPE_CODE for common error
+        result.addField(PostgreSQLErrorResponsePacket.FIELD_TYPE_MESSAGE, 
cause.getMessage());
         return result;
     }
 }

Reply via email to