This is an automated email from the ASF dual-hosted git repository.
menghaoranss pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/shardingsphere.git
The following commit(s) were added to refs/heads/master by this push:
new 6b32a2c52a0 Refactor metadata database lookup to avoid duplicate
contains/get access (#38774)
6b32a2c52a0 is described below
commit 6b32a2c52a0dcda2cd4027dcb1c6a7beda909d3b
Author: Haoran Meng <[email protected]>
AuthorDate: Wed Jun 3 08:39:25 2026 +0800
Refactor metadata database lookup to avoid duplicate contains/get access
(#38774)
* Refactor metadata database lookup to avoid duplicate contains/get access
* Add more unit tests
---
.../infra/metadata/ShardingSphereMetaData.java | 5 +-
.../infra/metadata/identifier/IdentifierIndex.java | 71 +++++++++++++++-------
.../metadata/identifier/IdentifierIndexTest.java | 29 +++++++++
.../infra/session/query/QueryContext.java | 5 +-
.../infra/session/query/QueryContextTest.java | 25 +++++++-
.../mode/manager/ContextManager.java | 5 +-
.../mode/manager/ContextManagerTest.java | 3 +-
.../backend/handler/ProxySQLComQueryParser.java | 9 ++-
.../handler/ProxyBackendHandlerFactoryTest.java | 6 +-
.../handler/ProxySQLComQueryParserTest.java | 15 ++++-
10 files changed, 131 insertions(+), 42 deletions(-)
diff --git
a/infra/common/src/main/java/org/apache/shardingsphere/infra/metadata/ShardingSphereMetaData.java
b/infra/common/src/main/java/org/apache/shardingsphere/infra/metadata/ShardingSphereMetaData.java
index e750aa8a05a..2e02f01f507 100644
---
a/infra/common/src/main/java/org/apache/shardingsphere/infra/metadata/ShardingSphereMetaData.java
+++
b/infra/common/src/main/java/org/apache/shardingsphere/infra/metadata/ShardingSphereMetaData.java
@@ -20,7 +20,6 @@ package org.apache.shardingsphere.infra.metadata;
import lombok.AccessLevel;
import lombok.Getter;
import lombok.SneakyThrows;
-import
org.apache.shardingsphere.database.connector.core.metadata.database.enums.QuoteCharacter;
import
org.apache.shardingsphere.database.connector.core.metadata.identifier.IdentifierScope;
import org.apache.shardingsphere.database.connector.core.type.DatabaseType;
import org.apache.shardingsphere.infra.database.DatabaseTypeEngine;
@@ -136,7 +135,7 @@ public final class ShardingSphereMetaData implements
AutoCloseable {
* @return contains database from meta data or not
*/
public boolean containsDatabase(final String databaseName) {
- return containsDatabase(new IdentifierValue(databaseName,
QuoteCharacter.NONE));
+ return databaseIndex.contains(databaseName);
}
/**
@@ -156,7 +155,7 @@ public final class ShardingSphereMetaData implements
AutoCloseable {
* @return meta data database
*/
public ShardingSphereDatabase getDatabase(final String databaseName) {
- return getDatabase(new IdentifierValue(databaseName,
QuoteCharacter.NONE));
+ return databaseIndex.get(databaseName);
}
/**
diff --git
a/infra/common/src/main/java/org/apache/shardingsphere/infra/metadata/identifier/IdentifierIndex.java
b/infra/common/src/main/java/org/apache/shardingsphere/infra/metadata/identifier/IdentifierIndex.java
index 568f2e59658..9f30e1881d5 100644
---
a/infra/common/src/main/java/org/apache/shardingsphere/infra/metadata/identifier/IdentifierIndex.java
+++
b/infra/common/src/main/java/org/apache/shardingsphere/infra/metadata/identifier/IdentifierIndex.java
@@ -144,6 +144,32 @@ public final class IdentifierIndex<T> {
return snapshot.getExactValues().size();
}
+ /**
+ * Judge whether contains metadata object by unquoted identifier or not; delegates to {@link #get(String)}, so its exceptions propagate.
+ *
+ * @param identifier unquoted identifier
+ * @return {@code true} if a non-null metadata object is found, otherwise {@code false}
+ */
+ public boolean contains(final String identifier) {
+ return null != get(identifier);
+ }
+
+ /**
+ * Get metadata object by unquoted identifier; may throw {@code AmbiguousIdentifierException} if several stored identifiers normalize alike and none matches exactly.
+ *
+ * @param identifier unquoted identifier, must not be {@code null}
+ * @return matched metadata object, or {@code null} if no metadata object matches
+ */
+ public T get(final String identifier) {
+ Objects.requireNonNull(identifier, "identifier cannot be null.");
+ Snapshot<T> currentSnapshot = snapshot;
+ IdentifierCaseRule rule =
databaseIdentifierContext.getRule(identifierScope);
+ if (LookupMode.EXACT == rule.getLookupMode(QuoteCharacter.NONE)) {
+ return currentSnapshot.getExactValues().get(identifier);
+ }
+ return getByNormalizedIdentifier(currentSnapshot, rule, identifier);
+ }
+
/**
* Find metadata object by identifier value.
*
@@ -157,35 +183,38 @@ public final class IdentifierIndex<T> {
if (LookupMode.EXACT ==
rule.getLookupMode(identifierValue.getQuoteCharacter())) {
return
Optional.ofNullable(currentSnapshot.getExactValues().get(identifierValue.getValue()));
}
- return findByNormalizedIdentifier(currentSnapshot, rule,
identifierValue);
+ return QuoteCharacter.NONE == identifierValue.getQuoteCharacter()
+ ?
Optional.ofNullable(getByNormalizedIdentifier(currentSnapshot, rule,
identifierValue.getValue()))
+ : findByQuotedNormalizedIdentifier(currentSnapshot, rule,
identifierValue);
}
- private Optional<T> findByNormalizedIdentifier(final Snapshot<T>
currentSnapshot, final IdentifierCaseRule rule, final IdentifierValue
identifierValue) {
- NormalizedBucket<T> normalizedBucket =
currentSnapshot.getNormalizedBuckets().get(rule.normalize(identifierValue.getValue()));
+ private T getByNormalizedIdentifier(final Snapshot<T> currentSnapshot,
final IdentifierCaseRule rule, final String identifier) {
+ NormalizedBucket<T> normalizedBucket =
currentSnapshot.getNormalizedBuckets().get(rule.normalize(identifier));
if (null == normalizedBucket) {
- return Optional.empty();
+ return null;
}
- return QuoteCharacter.NONE == identifierValue.getQuoteCharacter()
- ? findByUnquotedNormalizedIdentifier(currentSnapshot,
normalizedBucket, identifierValue.getValue())
- : findByQuotedNormalizedIdentifier(currentSnapshot, rule,
normalizedBucket, identifierValue);
+ return getByUnquotedNormalizedIdentifier(currentSnapshot,
normalizedBucket, identifier);
}
- private Optional<T> findByUnquotedNormalizedIdentifier(final Snapshot<T>
currentSnapshot, final NormalizedBucket<T> normalizedBucket, final String
identifierValue) {
+ private T getByUnquotedNormalizedIdentifier(final Snapshot<T>
currentSnapshot, final NormalizedBucket<T> normalizedBucket, final String
identifier) {
if (!normalizedBucket.hasUnquotedIdentifier()) {
- return Optional.empty();
+ return null;
}
if (normalizedBucket.hasSingleUnquotedIdentifier()) {
- return
Optional.ofNullable(normalizedBucket.getSingleUnquotedValue());
+ return normalizedBucket.getSingleUnquotedValue();
}
- Optional<T> exactMatchedValue =
findExactMatchedValue(currentSnapshot.getExactValues(),
normalizedBucket.getUnquotedIdentifiers(), identifierValue);
- if (exactMatchedValue.isPresent()) {
+ T exactMatchedValue =
findExactMatchedValue(currentSnapshot.getExactValues(),
normalizedBucket.getUnquotedIdentifiers(), identifier);
+ if (null != exactMatchedValue) {
return exactMatchedValue;
}
- throw new AmbiguousIdentifierException(identifierValue,
normalizedBucket.getUnquotedIdentifiers());
+ throw new AmbiguousIdentifierException(identifier,
normalizedBucket.getUnquotedIdentifiers());
}
- private Optional<T> findByQuotedNormalizedIdentifier(final Snapshot<T>
currentSnapshot, final IdentifierCaseRule rule,
- final
NormalizedBucket<T> normalizedBucket, final IdentifierValue identifierValue) {
+ private Optional<T> findByQuotedNormalizedIdentifier(final Snapshot<T>
currentSnapshot, final IdentifierCaseRule rule, final IdentifierValue
identifierValue) {
+ NormalizedBucket<T> normalizedBucket =
currentSnapshot.getNormalizedBuckets().get(rule.normalize(identifierValue.getValue()));
+ if (null == normalizedBucket) {
+ return Optional.empty();
+ }
if (normalizedBucket.hasSingleIdentifier()) {
return rule.matches(normalizedBucket.getSingleIdentifier(),
identifierValue.getValue(), identifierValue.getQuoteCharacter())
? Optional.ofNullable(normalizedBucket.getSingleValue())
@@ -213,19 +242,19 @@ public final class IdentifierIndex<T> {
if (null == ambiguousIdentifiers) {
return
Optional.ofNullable(currentSnapshot.getExactValues().get(matchedIdentifier));
}
- Optional<T> exactMatchedValue =
findExactMatchedValue(currentSnapshot.getExactValues(), ambiguousIdentifiers,
identifierValue.getValue());
- if (exactMatchedValue.isPresent()) {
- return exactMatchedValue;
+ T exactMatchedValue =
findExactMatchedValue(currentSnapshot.getExactValues(), ambiguousIdentifiers,
identifierValue.getValue());
+ if (null != exactMatchedValue) {
+ return Optional.of(exactMatchedValue);
}
throw new AmbiguousIdentifierException(identifierValue.getValue(),
ambiguousIdentifiers);
}
- private Optional<T> findExactMatchedValue(final Map<String, T>
exactValues, final Collection<String> matchedIdentifiers, final String
identifierValue) {
+ private T findExactMatchedValue(final Map<String, T> exactValues, final
Collection<String> matchedIdentifiers, final String identifierValue) {
if (!matchedIdentifiers.contains(identifierValue)) {
- return Optional.empty();
+ return null;
}
log.warn("Identifier '{}' matched multiple actual identifiers {}.
Fallback to exact identifier '{}'.", identifierValue, matchedIdentifiers,
identifierValue);
- return Optional.ofNullable(exactValues.get(identifierValue));
+ return exactValues.get(identifierValue);
}
private void addNormalizedIdentifier(final Map<String, Collection<String>>
values, final IdentifierCaseRule rule, final String name) {
diff --git
a/infra/common/src/test/java/org/apache/shardingsphere/infra/metadata/identifier/IdentifierIndexTest.java
b/infra/common/src/test/java/org/apache/shardingsphere/infra/metadata/identifier/IdentifierIndexTest.java
index c03ac009d1a..a331ca3e457 100644
---
a/infra/common/src/test/java/org/apache/shardingsphere/infra/metadata/identifier/IdentifierIndexTest.java
+++
b/infra/common/src/test/java/org/apache/shardingsphere/infra/metadata/identifier/IdentifierIndexTest.java
@@ -101,6 +101,20 @@ class IdentifierIndexTest {
assertThat(actualValue, is(Optional.of("value_1")));
}
+ @Test
+ void assertContainsWithNormalizedLookup() {
+ IdentifierIndex<String> index = new IdentifierIndex<>(new
DatabaseIdentifierContext(new IdentifierCaseRuleSet(createPostgreSQLRule())),
IdentifierScope.TABLE);
+ index.rebuild(createSingleValueMap("foo", "value_1"));
+ assertTrue(index.contains("FOO"));
+ }
+
+ @Test
+ void assertGetWithNormalizedLookup() {
+ IdentifierIndex<String> index = new IdentifierIndex<>(new
DatabaseIdentifierContext(new IdentifierCaseRuleSet(createPostgreSQLRule())),
IdentifierScope.TABLE);
+ index.rebuild(createSingleValueMap("foo", "value_1"));
+ assertThat(index.get("FOO"), is("value_1"));
+ }
+
@Test
void assertFindWithNormalizedLookupIgnoresNonMatchingStoredCase() {
IdentifierIndex<String> index = new IdentifierIndex<>(new
DatabaseIdentifierContext(new IdentifierCaseRuleSet(createPostgreSQLRule())),
IdentifierScope.TABLE);
@@ -128,6 +142,13 @@ class IdentifierIndexTest {
assertThat(actualValue, is(Optional.of("value_2")));
}
+ @Test
+ void
assertGetPrefersExactIdentifierWhenNormalizedLookupMatchesMultipleValues() {
+ IdentifierIndex<String> index = new IdentifierIndex<>(new
DatabaseIdentifierContext(new
IdentifierCaseRuleSet(createMySQLInsensitiveRule())), IdentifierScope.TABLE);
+ index.rebuild(createAmbiguousValueMap());
+ assertThat(index.get("foo"), is("value_2"));
+ }
+
@Test
void assertRemove() {
IdentifierIndex<String> index = new IdentifierIndex<>(new
DatabaseIdentifierContext(new IdentifierCaseRuleSet(createExactRule())),
IdentifierScope.TABLE);
@@ -175,6 +196,14 @@ class IdentifierIndexTest {
assertThat(actualException.getMessage(), is("Identifier 'FOO' is
ambiguous, matched actual identifiers: Foo, foo."));
}
+ @Test
+ void assertGetThrowsAmbiguousIdentifierException() {
+ IdentifierIndex<String> index = new IdentifierIndex<>(new
DatabaseIdentifierContext(new
IdentifierCaseRuleSet(createMySQLInsensitiveRule())), IdentifierScope.TABLE);
+ index.rebuild(createAmbiguousValueMap());
+ AmbiguousIdentifierException actualException =
assertThrows(AmbiguousIdentifierException.class, () -> index.get("FOO"));
+ assertThat(actualException.getMessage(), is("Identifier 'FOO' is
ambiguous, matched actual identifiers: Foo, foo."));
+ }
+
private IdentifierCaseRule createExactRule() {
return new StandardIdentifierCaseRule(LookupMode.EXACT,
LookupMode.EXACT, each -> each, each -> true);
}
diff --git
a/infra/session/src/main/java/org/apache/shardingsphere/infra/session/query/QueryContext.java
b/infra/session/src/main/java/org/apache/shardingsphere/infra/session/query/QueryContext.java
index 3d780c5a85f..d3b6b265f97 100644
---
a/infra/session/src/main/java/org/apache/shardingsphere/infra/session/query/QueryContext.java
+++
b/infra/session/src/main/java/org/apache/shardingsphere/infra/session/query/QueryContext.java
@@ -133,7 +133,8 @@ public final class QueryContext {
private ShardingSphereDatabase getDatabase(final String databaseName) {
ShardingSpherePreconditions.checkNotNull(databaseName,
NoDatabaseSelectedException::new);
-
ShardingSpherePreconditions.checkState(metaData.containsDatabase(databaseName),
() -> new UnknownDatabaseException(databaseName));
- return metaData.getDatabase(databaseName);
+ ShardingSphereDatabase result = metaData.getDatabase(databaseName);
+ ShardingSpherePreconditions.checkState(null != result, () -> new
UnknownDatabaseException(databaseName));
+ return result;
}
}
diff --git
a/infra/session/src/test/java/org/apache/shardingsphere/infra/session/query/QueryContextTest.java
b/infra/session/src/test/java/org/apache/shardingsphere/infra/session/query/QueryContextTest.java
index 42d3af8d9f9..cfec2daf605 100644
---
a/infra/session/src/test/java/org/apache/shardingsphere/infra/session/query/QueryContextTest.java
+++
b/infra/session/src/test/java/org/apache/shardingsphere/infra/session/query/QueryContextTest.java
@@ -30,6 +30,7 @@ import org.junit.jupiter.api.Test;
import java.util.Collections;
import java.util.Optional;
+import static org.hamcrest.Matchers.contains;
import static org.hamcrest.Matchers.is;
import static org.hamcrest.MatcherAssert.assertThat;
import static org.junit.jupiter.api.Assertions.assertThrows;
@@ -45,7 +46,6 @@ class QueryContextTest {
when(connectionContext.getCurrentDatabaseName()).thenReturn(Optional.of("foo_db"));
ShardingSphereMetaData metaData = mock(ShardingSphereMetaData.class);
ShardingSphereDatabase database = mock(ShardingSphereDatabase.class);
- when(metaData.containsDatabase("foo_db")).thenReturn(true);
when(metaData.getDatabase("foo_db")).thenReturn(database);
SQLStatementContext sqlStatementContext =
mock(SQLStatementContext.class, RETURNS_DEEP_STUBS);
when(sqlStatementContext.getTablesContext().getDatabaseNames()).thenReturn(Collections.emptyList());
@@ -72,7 +72,6 @@ class QueryContextTest {
ShardingSphereDatabase database = mock(ShardingSphereDatabase.class);
when(database.isComplete()).thenReturn(true);
when(database.getName()).thenReturn("foo_db");
- when(metaData.containsDatabase("foo_db")).thenReturn(true);
when(metaData.getAllDatabases()).thenReturn(Collections.singletonList(database));
when(metaData.getDatabase("foo_db")).thenReturn(database);
SQLStatementContext sqlStatementContext =
mock(SQLStatementContext.class, RETURNS_DEEP_STUBS);
@@ -93,4 +92,26 @@ class QueryContextTest {
QueryContext actual = new QueryContext(sqlStatementContext, "SELECT
1", Collections.emptyList(), mock(HintValueContext.class), connectionContext,
metaData);
assertThrows(UnknownDatabaseException.class, actual::getUsedDatabase);
}
+
+ @Test
+ void assertGetUsedDatabases() {
+ ConnectionContext connectionContext = mock(ConnectionContext.class);
+ ShardingSphereMetaData metaData = mock(ShardingSphereMetaData.class);
+ ShardingSphereDatabase database = mock(ShardingSphereDatabase.class);
+ SQLStatementContext sqlStatementContext =
mock(SQLStatementContext.class, RETURNS_DEEP_STUBS);
+
when(sqlStatementContext.getTablesContext().getDatabaseNames()).thenReturn(Collections.singleton("foo_db"));
+ when(metaData.getDatabase("foo_db")).thenReturn(database);
+ QueryContext actual = new QueryContext(sqlStatementContext, "SELECT
1", Collections.emptyList(), mock(HintValueContext.class), connectionContext,
metaData);
+ assertThat(actual.getUsedDatabases(), contains(database));
+ }
+
+ @Test
+ void assertGetUsedDatabasesWhenUseUnknownDatabase() {
+ ConnectionContext connectionContext = mock(ConnectionContext.class);
+ ShardingSphereMetaData metaData = mock(ShardingSphereMetaData.class);
+ SQLStatementContext sqlStatementContext =
mock(SQLStatementContext.class, RETURNS_DEEP_STUBS);
+
when(sqlStatementContext.getTablesContext().getDatabaseNames()).thenReturn(Collections.singleton("unknown_db"));
+ QueryContext actual = new QueryContext(sqlStatementContext, "SELECT
1", Collections.emptyList(), mock(HintValueContext.class), connectionContext,
metaData);
+ assertThrows(UnknownDatabaseException.class, actual::getUsedDatabases);
+ }
}
diff --git
a/mode/core/src/main/java/org/apache/shardingsphere/mode/manager/ContextManager.java
b/mode/core/src/main/java/org/apache/shardingsphere/mode/manager/ContextManager.java
index ea8ab8f2cd3..4eabfe64fe3 100644
---
a/mode/core/src/main/java/org/apache/shardingsphere/mode/manager/ContextManager.java
+++
b/mode/core/src/main/java/org/apache/shardingsphere/mode/manager/ContextManager.java
@@ -121,8 +121,9 @@ public final class ContextManager implements AutoCloseable {
public ShardingSphereDatabase getDatabase(final String name) {
ShardingSpherePreconditions.checkNotEmpty(name,
NoDatabaseSelectedException::new);
ShardingSphereMetaData metaData = metaDataContexts.getMetaData();
-
ShardingSpherePreconditions.checkState(metaData.containsDatabase(name), () ->
new UnknownDatabaseException(name));
- return metaData.getDatabase(name);
+ ShardingSphereDatabase result = metaData.getDatabase(name);
+ ShardingSpherePreconditions.checkState(null != result, () -> new
UnknownDatabaseException(name));
+ return result;
}
/**
diff --git
a/mode/core/src/test/java/org/apache/shardingsphere/mode/manager/ContextManagerTest.java
b/mode/core/src/test/java/org/apache/shardingsphere/mode/manager/ContextManagerTest.java
index 7a32dd8f910..06ce651ede5 100644
---
a/mode/core/src/test/java/org/apache/shardingsphere/mode/manager/ContextManagerTest.java
+++
b/mode/core/src/test/java/org/apache/shardingsphere/mode/manager/ContextManagerTest.java
@@ -114,8 +114,8 @@ class ContextManagerTest {
when(metaDataContexts.getMetaData().getProps()).thenReturn(new
ConfigurationProperties(new Properties()));
when(metaDataContexts.getMetaData().getTemporaryProps()).thenReturn(new
TemporaryConfigurationProperties(new Properties()));
database = mockDatabase();
-
when(metaDataContexts.getMetaData().containsDatabase("foo_db")).thenReturn(true);
when(metaDataContexts.getMetaData().getDatabase("foo_db")).thenReturn(database);
+
when(metaDataContexts.getMetaData().getDatabase("bar_db")).thenReturn(null);
when(metaDataContexts.getMetaData().getAllDatabases()).thenReturn(Collections.singleton(database));
when(computeNodeInstanceContext.getInstance()).thenReturn(new
ComputeNodeInstance(new ProxyInstanceMetaData("foo_id", 3307),
Collections.emptyList()));
when(computeNodeInstanceContext.getModeConfiguration()).thenReturn(new
ModeConfiguration("FIXTURE", mock()));
@@ -182,7 +182,6 @@ class ContextManagerTest {
ShardingSphereDatabase database =
new ShardingSphereDatabase("foo_db", mock(DatabaseType.class),
resourceMetaData, mock(RuleMetaData.class), Collections.emptyList(), new
ConfigurationProperties(new Properties()));
when(metaDataContexts.getMetaData().getDatabase("foo_db")).thenReturn(database);
-
when(metaDataContexts.getMetaData().containsDatabase("foo_db")).thenReturn(true);
assertThat(contextManager.getStorageUnits("foo_db").size(), is(1));
}
diff --git
a/proxy/backend/core/src/main/java/org/apache/shardingsphere/proxy/backend/handler/ProxySQLComQueryParser.java
b/proxy/backend/core/src/main/java/org/apache/shardingsphere/proxy/backend/handler/ProxySQLComQueryParser.java
index 8bae93b5c8b..33a1387e765 100644
---
a/proxy/backend/core/src/main/java/org/apache/shardingsphere/proxy/backend/handler/ProxySQLComQueryParser.java
+++
b/proxy/backend/core/src/main/java/org/apache/shardingsphere/proxy/backend/handler/ProxySQLComQueryParser.java
@@ -22,6 +22,7 @@ import lombok.AccessLevel;
import lombok.NoArgsConstructor;
import org.apache.shardingsphere.database.connector.core.type.DatabaseType;
import org.apache.shardingsphere.infra.metadata.ShardingSphereMetaData;
+import
org.apache.shardingsphere.infra.metadata.database.ShardingSphereDatabase;
import org.apache.shardingsphere.parser.rule.SQLParserRule;
import org.apache.shardingsphere.proxy.backend.context.ProxyContext;
import org.apache.shardingsphere.proxy.backend.session.ConnectionSession;
@@ -55,8 +56,10 @@ public final class ProxySQLComQueryParser {
private static DatabaseType getParserDatabaseType(final DatabaseType
defaultDatabaseType, final ConnectionSession connectionSession) {
String databaseName = connectionSession.getUsedDatabaseName();
- return Strings.isNullOrEmpty(databaseName) ||
!ProxyContext.getInstance().getContextManager().getMetaDataContexts().getMetaData().containsDatabase(databaseName)
- ? defaultDatabaseType
- :
ProxyContext.getInstance().getContextManager().getDatabase(databaseName).getProtocolType();
+ if (Strings.isNullOrEmpty(databaseName)) {
+ return defaultDatabaseType;
+ }
+ ShardingSphereDatabase database =
ProxyContext.getInstance().getContextManager().getMetaDataContexts().getMetaData().getDatabase(databaseName);
+ return null == database ? defaultDatabaseType :
database.getProtocolType();
}
}
diff --git
a/proxy/backend/core/src/test/java/org/apache/shardingsphere/proxy/backend/handler/ProxyBackendHandlerFactoryTest.java
b/proxy/backend/core/src/test/java/org/apache/shardingsphere/proxy/backend/handler/ProxyBackendHandlerFactoryTest.java
index f4fc5644e52..91a8b6fecad 100644
---
a/proxy/backend/core/src/test/java/org/apache/shardingsphere/proxy/backend/handler/ProxyBackendHandlerFactoryTest.java
+++
b/proxy/backend/core/src/test/java/org/apache/shardingsphere/proxy/backend/handler/ProxyBackendHandlerFactoryTest.java
@@ -136,13 +136,11 @@ class ProxyBackendHandlerFactoryTest {
private ContextManager mockContextManager() {
MetaDataContexts metaDataContexts = mock(MetaDataContexts.class,
RETURNS_DEEP_STUBS);
- DatabaseType dbProtocolType = mock(DatabaseType.class);
ResourceMetaData dbResourceMetaData = mock(ResourceMetaData.class);
- ShardingSphereDatabase database = new ShardingSphereDatabase("db",
dbProtocolType, dbResourceMetaData,
+ ShardingSphereDatabase database = new ShardingSphereDatabase("db",
databaseType, dbResourceMetaData,
new RuleMetaData(Collections.emptyList()),
Collections.emptyList(), new ConfigurationProperties(new Properties()));
- DatabaseType informationSchemaProtocolType = mock(DatabaseType.class);
ResourceMetaData informationSchemaResourceMetaData =
mock(ResourceMetaData.class);
- ShardingSphereDatabase informationSchemaDatabase = new
ShardingSphereDatabase("information_schema", informationSchemaProtocolType,
informationSchemaResourceMetaData,
+ ShardingSphereDatabase informationSchemaDatabase = new
ShardingSphereDatabase("information_schema", databaseType,
informationSchemaResourceMetaData,
new RuleMetaData(Collections.emptyList()),
Collections.emptyList(), new ConfigurationProperties(new Properties()));
when(metaDataContexts.getMetaData().getDatabase("db")).thenReturn(database);
when(metaDataContexts.getMetaData().getDatabase("information_schema")).thenReturn(informationSchemaDatabase);
diff --git
a/proxy/backend/core/src/test/java/org/apache/shardingsphere/proxy/backend/handler/ProxySQLComQueryParserTest.java
b/proxy/backend/core/src/test/java/org/apache/shardingsphere/proxy/backend/handler/ProxySQLComQueryParserTest.java
index 3fc7d2ed0e3..dcdc161b8bc 100644
---
a/proxy/backend/core/src/test/java/org/apache/shardingsphere/proxy/backend/handler/ProxySQLComQueryParserTest.java
+++
b/proxy/backend/core/src/test/java/org/apache/shardingsphere/proxy/backend/handler/ProxySQLComQueryParserTest.java
@@ -44,7 +44,6 @@ import static org.mockito.ArgumentMatchers.any;
import static org.mockito.ArgumentMatchers.anyString;
import static org.mockito.ArgumentMatchers.eq;
import static org.mockito.Mockito.RETURNS_DEEP_STUBS;
-import static org.mockito.Mockito.lenient;
import static org.mockito.Mockito.mock;
import static org.mockito.Mockito.when;
@@ -65,6 +64,17 @@ class ProxySQLComQueryParserTest {
assertThat(actual.getDatabaseType(), is(databaseType));
}
+ @Test
+ void assertParseWithEmptyDatabaseNameUseDefaultType() {
+ when(connectionSession.getUsedDatabaseName()).thenReturn("");
+ SQLParserEngine parserEngine = mock(SQLParserEngine.class);
+ SQLStatement expected =
UpdateStatement.builder().databaseType(databaseType).build();
+ SQLParserRule parserRule = mockParserRule(parserEngine);
+ mockProxyContext(parserRule, false, null);
+ when(parserEngine.parse(anyString(), eq(false))).thenReturn(expected);
+ assertThat(ProxySQLComQueryParser.parse("select 1", databaseType,
connectionSession), is(expected));
+ }
+
@Test
void assertParseWithMissingDatabaseUseDefaultType() {
when(connectionSession.getUsedDatabaseName()).thenReturn("missing_db");
@@ -97,10 +107,9 @@ class ProxySQLComQueryParserTest {
private void mockProxyContext(final SQLParserRule parserRule, final
boolean containsDatabase, final DatabaseType protocolType) {
ContextManager contextManager = mock(ContextManager.class,
RETURNS_DEEP_STUBS);
when(contextManager.getMetaDataContexts().getMetaData().getGlobalRuleMetaData()).thenReturn(new
RuleMetaData(Collections.singleton(parserRule)));
-
lenient().when(contextManager.getMetaDataContexts().getMetaData().containsDatabase("logic_db")).thenReturn(containsDatabase);
if (containsDatabase) {
ShardingSphereDatabase database =
mock(ShardingSphereDatabase.class);
- when(contextManager.getDatabase(anyString())).thenReturn(database);
+
when(contextManager.getMetaDataContexts().getMetaData().getDatabase("logic_db")).thenReturn(database);
when(database.getProtocolType()).thenReturn(protocolType);
}
when(ProxyContext.getInstance().getContextManager()).thenReturn(contextManager);