This is an automated email from the ASF dual-hosted git repository.
wuweijie pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/shardingsphere.git
The following commit(s) were added to refs/heads/master by this push:
new d312e29 Refactor PostgreSQLComStartupPacket (#10091)
d312e29 is described below
commit d312e297b1233740f45455a28a546948a55ce590
Author: Liang Zhang <[email protected]>
AuthorDate: Thu Apr 15 00:53:12 2021 +0800
Refactor PostgreSQLComStartupPacket (#10091)
* Refactor PostgreSQLComStartupPacket
* Refactor PostgreSQLComStartupPacket
* Refactor PostgreSQLAuthenticationEngine
* Refactor PostgreSQLAuthenticationEngine
---
.../handshake/PostgreSQLComStartupPacket.java | 26 +++++++--
.../postgresql/packet/ByteBufTestUtils.java | 6 +--
.../generic/PostgreSQLComStartupPacketTest.java | 61 +++++++++++++++-------
.../PostgreSQLAuthenticationEngine.java | 20 +++----
4 files changed, 74 insertions(+), 39 deletions(-)
diff --git a/shardingsphere-db-protocol/shardingsphere-db-protocol-postgresql/src/main/java/org/apache/shardingsphere/db/protocol/postgresql/packet/handshake/PostgreSQLComStartupPacket.java b/shardingsphere-db-protocol/shardingsphere-db-protocol-postgresql/src/main/java/org/apache/shardingsphere/db/protocol/postgresql/packet/handshake/PostgreSQLComStartupPacket.java
index 7a86107..7fcedd4 100644
--- a/shardingsphere-db-protocol/shardingsphere-db-protocol-postgresql/src/main/java/org/apache/shardingsphere/db/protocol/postgresql/packet/handshake/PostgreSQLComStartupPacket.java
+++ b/shardingsphere-db-protocol/shardingsphere-db-protocol-postgresql/src/main/java/org/apache/shardingsphere/db/protocol/postgresql/packet/handshake/PostgreSQLComStartupPacket.java
@@ -17,7 +17,6 @@
package org.apache.shardingsphere.db.protocol.postgresql.packet.handshake;
-import lombok.Getter;
import
org.apache.shardingsphere.db.protocol.postgresql.packet.PostgreSQLPacket;
import
org.apache.shardingsphere.db.protocol.postgresql.payload.PostgreSQLPacketPayload;
@@ -27,10 +26,13 @@ import java.util.Map;
/**
* Startup packet for PostgreSQL.
*/
-@Getter
public final class PostgreSQLComStartupPacket implements PostgreSQLPacket {
- private final Map<String, String> parametersMap = new HashMap<>(16, 1);
+ private static final String DATABASE_NAME_KEY = "database";
+
+ private static final String USER_NAME_KEY = "user";
+
+ private final Map<String, String> parametersMap = new HashMap<>();
public PostgreSQLComStartupPacket(final PostgreSQLPacketPayload payload) {
payload.skipReserved(8);
@@ -39,6 +41,24 @@ public final class PostgreSQLComStartupPacket implements PostgreSQLPacket {
}
}
+ /**
+ * Get database.
+ *
+ * @return database
+ */
+ public String getDatabase() {
+ return parametersMap.get(DATABASE_NAME_KEY);
+ }
+
+ /**
+ * Get user.
+ *
+ * @return user
+ */
+ public String getUser() {
+ return parametersMap.get(USER_NAME_KEY);
+ }
+
@Override
public void write(final PostgreSQLPacketPayload payload) {
}
diff --git a/shardingsphere-db-protocol/shardingsphere-db-protocol-postgresql/src/test/java/org/apache/shardingsphere/db/protocol/postgresql/packet/ByteBufTestUtils.java b/shardingsphere-db-protocol/shardingsphere-db-protocol-postgresql/src/test/java/org/apache/shardingsphere/db/protocol/postgresql/packet/ByteBufTestUtils.java
index a84bbce..f3249c3 100644
--- a/shardingsphere-db-protocol/shardingsphere-db-protocol-postgresql/src/test/java/org/apache/shardingsphere/db/protocol/postgresql/packet/ByteBufTestUtils.java
+++ b/shardingsphere-db-protocol/shardingsphere-db-protocol-postgresql/src/test/java/org/apache/shardingsphere/db/protocol/postgresql/packet/ByteBufTestUtils.java
@@ -30,7 +30,7 @@ public final class ByteBufTestUtils {
* Creates a new buffer with a newly allocated byte array, fixed capacity.
*
* @param capacity the fixed capacity of the underlying byte array
- * @return ByteBuf
+ * @return byte buffer
*/
public static ByteBuf createByteBuf(final int capacity) {
return createByteBuf(capacity, capacity);
@@ -40,8 +40,8 @@ public final class ByteBufTestUtils {
* Creates a new buffer with a newly allocated byte array.
*
* @param initialCapacity the initial capacity of the underlying byte array
- * @param maxCapacity the max capacity of the underlying byte array
- * @return ByteBuf
+ * @param maxCapacity the max capacity of the underlying byte array
+ * @return byte buffer
*/
public static ByteBuf createByteBuf(final int initialCapacity, final int
maxCapacity) {
UnpooledByteBufAllocator byteBufAllocator =
UnpooledByteBufAllocator.DEFAULT;
diff --git a/shardingsphere-db-protocol/shardingsphere-db-protocol-postgresql/src/test/java/org/apache/shardingsphere/db/protocol/postgresql/packet/generic/PostgreSQLComStartupPacketTest.java b/shardingsphere-db-protocol/shardingsphere-db-protocol-postgresql/src/test/java/org/apache/shardingsphere/db/protocol/postgresql/packet/generic/PostgreSQLComStartupPacketTest.java
index 44130f3..e2aa129 100644
--- a/shardingsphere-db-protocol/shardingsphere-db-protocol-postgresql/src/test/java/org/apache/shardingsphere/db/protocol/postgresql/packet/generic/PostgreSQLComStartupPacketTest.java
+++ b/shardingsphere-db-protocol/shardingsphere-db-protocol-postgresql/src/test/java/org/apache/shardingsphere/db/protocol/postgresql/packet/generic/PostgreSQLComStartupPacketTest.java
@@ -25,34 +25,55 @@ import org.junit.Test;
import java.util.LinkedHashMap;
import java.util.Map;
+import java.util.Map.Entry;
import static org.hamcrest.CoreMatchers.is;
import static org.junit.Assert.assertThat;
+import static org.mockito.Mockito.mock;
public final class PostgreSQLComStartupPacketTest {
@Test
- public void assertReadWrite() {
- Map<String, String> expectedParametersMap = new LinkedHashMap<>();
- expectedParametersMap.put("user", "postgres");
- expectedParametersMap.put("database", "postgres");
- int expectedLength = 4 + 4;
- for (Map.Entry<String, String> each :
expectedParametersMap.entrySet()) {
- expectedLength += each.getKey().length() + 1;
- expectedLength += each.getValue().length() + 1;
+ public void assertNewPostgreSQLComStartupPacket() {
+ Map<String, String> parametersMap = createParametersMap();
+ int packetMessageLength = getPacketMessageLength(parametersMap);
+ ByteBuf byteBuf = ByteBufTestUtils.createByteBuf(packetMessageLength);
+ PostgreSQLPacketPayload payload = createPayload(parametersMap,
packetMessageLength, byteBuf);
+ PostgreSQLComStartupPacket actual = new
PostgreSQLComStartupPacket(payload);
+ assertThat(actual.getDatabase(), is("test_db"));
+ assertThat(actual.getUser(), is("postgres"));
+ assertThat(byteBuf.writerIndex(), is(packetMessageLength));
+ }
+
+ private Map<String, String> createParametersMap() {
+ Map<String, String> result = new LinkedHashMap<>(2, 1);
+ result.put("database", "test_db");
+ result.put("user", "postgres");
+ return result;
+ }
+
+ private int getPacketMessageLength(final Map<String, String>
parametersMap) {
+ int result = 4 + 4;
+ for (Entry<String, String> entry : parametersMap.entrySet()) {
+ result += entry.getKey().length() + 1;
+ result += entry.getValue().length() + 1;
}
- ByteBuf byteBuf = ByteBufTestUtils.createByteBuf(expectedLength);
- PostgreSQLPacketPayload payload = new PostgreSQLPacketPayload(byteBuf);
- payload.writeInt4(expectedLength);
- payload.writeInt4(196608);
- for (Map.Entry<String, String> each :
expectedParametersMap.entrySet()) {
- payload.writeStringNul(each.getKey());
- payload.writeStringNul(each.getValue());
+ return result;
+ }
+
+ private PostgreSQLPacketPayload createPayload(final Map<String, String>
actualParametersMap, final int actualMessageLength, final ByteBuf byteBuf) {
+ PostgreSQLPacketPayload result = new PostgreSQLPacketPayload(byteBuf);
+ result.writeInt4(actualMessageLength);
+ result.writeInt4(196608);
+ for (Entry<String, String> entry : actualParametersMap.entrySet()) {
+ result.writeStringNul(entry.getKey());
+ result.writeStringNul(entry.getValue());
}
- PostgreSQLComStartupPacket packet = new
PostgreSQLComStartupPacket(payload);
- Map<String, String> actualParametersMap = packet.getParametersMap();
- assertThat(actualParametersMap, is(expectedParametersMap));
- packet.write(payload);
- assertThat(byteBuf.writerIndex(), is(expectedLength));
+ return result;
+ }
+
+ @Test
+ public void assertWrite() {
+ new
PostgreSQLComStartupPacket(mock(PostgreSQLPacketPayload.class)).write(mock(PostgreSQLPacketPayload.class));
}
}
diff --git a/shardingsphere-proxy/shardingsphere-proxy-frontend/shardingsphere-proxy-frontend-postgresql/src/main/java/org/apache/shardingsphere/proxy/frontend/postgresql/authentication/PostgreSQLAuthenticationEngine.java b/shardingsphere-proxy/shardingsphere-proxy-frontend/shardingsphere-proxy-frontend-postgresql/src/main/java/org/apache/shardingsphere/proxy/frontend/postgresql/authentication/PostgreSQLAuthenticationEngine.java
index 4a51655..33950dd 100644
--- a/shardingsphere-proxy/shardingsphere-proxy-frontend/shardingsphere-proxy-frontend-postgresql/src/main/java/org/apache/shardingsphere/proxy/frontend/postgresql/authentication/PostgreSQLAuthenticationEngine.java
+++ b/shardingsphere-proxy/shardingsphere-proxy-frontend/shardingsphere-proxy-frontend-postgresql/src/main/java/org/apache/shardingsphere/proxy/frontend/postgresql/authentication/PostgreSQLAuthenticationEngine.java
@@ -49,10 +49,6 @@ public final class PostgreSQLAuthenticationEngine implements AuthenticationEngin
private static final int SSL_REQUEST_CODE = 80877103;
- private static final String USER_NAME_KEYWORD = "user";
-
- private static final String DATABASE_NAME_KEYWORD = "database";
-
private final AtomicBoolean startupMessageReceived = new
AtomicBoolean(false);
private volatile byte[] md5Salt;
@@ -79,23 +75,21 @@ public final class PostgreSQLAuthenticationEngine implements AuthenticationEngin
private AuthenticationResult beforeStartupMessage(final
ChannelHandlerContext context, final PostgreSQLPacketPayload payload) {
PostgreSQLComStartupPacket comStartupPacket = new
PostgreSQLComStartupPacket(payload);
startupMessageReceived.set(true);
- String databaseName =
comStartupPacket.getParametersMap().get(DATABASE_NAME_KEYWORD);
- if (!Strings.isNullOrEmpty(databaseName) &&
!ProxyContext.getInstance().schemaExists(databaseName)) {
- PostgreSQLErrorResponsePacket responsePacket =
createErrorPacket(PostgreSQLErrorCode.INVALID_CATALOG_NAME,
String.format("database \"%s\" does not exist", databaseName));
- context.writeAndFlush(responsePacket);
+ String database = comStartupPacket.getDatabase();
+ if (!Strings.isNullOrEmpty(database) &&
!ProxyContext.getInstance().schemaExists(database)) {
+
context.writeAndFlush(createErrorPacket(PostgreSQLErrorCode.INVALID_CATALOG_NAME,
String.format("database \"%s\" does not exist", database)));
context.close();
return AuthenticationResultBuilder.continued();
}
- String username =
comStartupPacket.getParametersMap().get(USER_NAME_KEYWORD);
- if (null == username || username.isEmpty()) {
- PostgreSQLErrorResponsePacket responsePacket =
createErrorPacket(PostgreSQLErrorCode.SQLSERVER_REJECTED_ESTABLISHMENT_OF_SQLCONNECTION,
"user not set in StartupMessage");
- context.writeAndFlush(responsePacket);
+ String user = comStartupPacket.getUser();
+ if (Strings.isNullOrEmpty(user)) {
+
context.writeAndFlush(createErrorPacket(PostgreSQLErrorCode.SQLSERVER_REJECTED_ESTABLISHMENT_OF_SQLCONNECTION,
"user not set in StartupMessage"));
context.close();
return AuthenticationResultBuilder.continued();
}
md5Salt =
PostgreSQLRandomGenerator.getInstance().generateRandomBytes(4);
context.writeAndFlush(new
PostgreSQLAuthenticationMD5PasswordPacket(md5Salt));
- currentAuthResult = AuthenticationResultBuilder.continued(username,
"", databaseName);
+ currentAuthResult = AuthenticationResultBuilder.continued(user, "",
database);
return currentAuthResult;
}