This is an automated email from the ASF dual-hosted git repository.

zhaojinchao pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/shardingsphere.git


The following commit(s) were added to refs/heads/master by this push:
     new 217df4958f6 Refactor MySQLMultiStatementsHandler logic (#29917)
217df4958f6 is described below

commit 217df4958f6ae362da37336170c817883c14ac13
Author: Zhengqiang Duan <[email protected]>
AuthorDate: Tue Jan 30 19:20:02 2024 +0800

    Refactor MySQLMultiStatementsHandler logic (#29917)
---
 .../text/query/MySQLMultiStatementsHandler.java    | 75 +++++++++++++---------
 1 file changed, 44 insertions(+), 31 deletions(-)

diff --git a/proxy/frontend/type/mysql/src/main/java/org/apache/shardingsphere/proxy/frontend/mysql/command/query/text/query/MySQLMultiStatementsHandler.java b/proxy/frontend/type/mysql/src/main/java/org/apache/shardingsphere/proxy/frontend/mysql/command/query/text/query/MySQLMultiStatementsHandler.java
index c3a43cfb412..301c3d4000b 100644
--- a/proxy/frontend/type/mysql/src/main/java/org/apache/shardingsphere/proxy/frontend/mysql/command/query/text/query/MySQLMultiStatementsHandler.java
+++ b/proxy/frontend/type/mysql/src/main/java/org/apache/shardingsphere/proxy/frontend/mysql/command/query/text/query/MySQLMultiStatementsHandler.java
@@ -63,7 +63,6 @@ import java.util.Arrays;
 import java.util.Collection;
 import java.util.Collections;
 import java.util.HashMap;
-import java.util.LinkedHashMap;
 import java.util.LinkedList;
 import java.util.List;
 import java.util.Map;
@@ -89,9 +88,7 @@ public final class MySQLMultiStatementsHandler implements ProxyBackendHandler {
     
     private final MetaDataContexts metaDataContexts = 
ProxyContext.getInstance().getContextManager().getMetaDataContexts();
     
-    private final Map<String, List<ExecutionUnit>> dataSourcesToExecutionUnits 
= new HashMap<>();
-    
-    private final Map<String, ExecutionContext> multiSQLExecutionContexts = 
new LinkedHashMap<>();
+    private final Collection<QueryContext> multiSQLQueryContexts = new 
LinkedList<>();
     
     public MySQLMultiStatementsHandler(final ConnectionSession 
connectionSession, final SQLStatement sqlStatementSample, final String sql) {
         jdbcExecutor = new 
JDBCExecutor(BackendExecutorContext.getInstance().getExecutorEngine(), 
connectionSession.getConnectionContext());
@@ -102,15 +99,7 @@ public final class MySQLMultiStatementsHandler implements ProxyBackendHandler {
         SQLParserEngine sqlParserEngine = getSQLParserEngine();
         for (String each : extractMultiStatements(pattern, sql)) {
             SQLStatement eachSQLStatement = sqlParserEngine.parse(each, false);
-            QueryContext queryContext = createQueryContext(each, 
eachSQLStatement);
-            if (null == connectionSession.getQueryContext()) {
-                connectionSession.setQueryContext(queryContext);
-            }
-            ExecutionContext executionContext = 
createExecutionContext(queryContext);
-            multiSQLExecutionContexts.putIfAbsent(each, executionContext);
-            for (ExecutionUnit eachExecutionUnit : 
executionContext.getExecutionUnits()) {
-                
dataSourcesToExecutionUnits.computeIfAbsent(eachExecutionUnit.getDataSourceName(),
 unused -> new LinkedList<>()).add(eachExecutionUnit);
-            }
+            multiSQLQueryContexts.add(createQueryContext(each, 
eachSQLStatement));
         }
     }
     
@@ -130,13 +119,6 @@ public final class MySQLMultiStatementsHandler implements ProxyBackendHandler {
         return new QueryContext(sqlStatementContext, sql, 
Collections.emptyList());
     }
     
-    private ExecutionContext createExecutionContext(final QueryContext 
queryContext) {
-        RuleMetaData globalRuleMetaData = 
metaDataContexts.getMetaData().getGlobalRuleMetaData();
-        ShardingSphereDatabase currentDatabase = 
metaDataContexts.getMetaData().getDatabase(connectionSession.getDatabaseName());
-        SQLAuditEngine.audit(queryContext.getSqlStatementContext(), 
queryContext.getParameters(), globalRuleMetaData, currentDatabase, null, 
queryContext.getHintValueContext());
-        return kernelProcessor.generateExecutionContext(queryContext, 
currentDatabase, globalRuleMetaData, metaDataContexts.getMetaData().getProps(), 
connectionSession.getConnectionContext());
-    }
-    
     @Override
     public ResponseHeader execute() throws SQLException {
         Collection<ShardingSphereRule> rules = 
metaDataContexts.getMetaData().getDatabase(connectionSession.getDatabaseName()).getRuleMetaData().getRules();
@@ -144,18 +126,49 @@ public final class MySQLMultiStatementsHandler implements ProxyBackendHandler {
                 
.<Integer>getValue(ConfigurationPropertyKey.MAX_CONNECTIONS_SIZE_PER_QUERY), 
connectionSession.getDatabaseConnectionManager(),
                 (JDBCBackendStatement) 
connectionSession.getStatementManager(), new StatementOption(false), rules,
                 
metaDataContexts.getMetaData().getDatabase(connectionSession.getDatabaseName()).getResourceMetaData().getStorageUnits());
-        ExecutionContext sampleExecutionContext = 
multiSQLExecutionContexts.values().iterator().next();
-        ExecutionGroupContext<JDBCExecutionUnit> executionGroupContext = 
prepareEngine.prepare(sampleExecutionContext.getRouteContext(), 
samplingExecutionUnit(),
-                new 
ExecutionGroupReportContext(connectionSession.getProcessId(), 
connectionSession.getDatabaseName(), connectionSession.getGrantee()));
-        for (ExecutionGroup<JDBCExecutionUnit> eachGroup : 
executionGroupContext.getInputGroups()) {
-            for (JDBCExecutionUnit each : eachGroup.getInputs()) {
-                prepareBatchedStatement(each);
+        return executeMultiStatements(prepareEngine);
+    }
+    
+    private UpdateResponseHeader executeMultiStatements(final 
DriverExecutionPrepareEngine<JDBCExecutionUnit, Connection> prepareEngine) 
throws SQLException {
+        Collection<ExecutionContext> executionContexts = 
createExecutionContexts();
+        Map<String, List<ExecutionUnit>> dataSourcesToExecutionUnits = 
buildDataSourcesToExecutionUnits(executionContexts);
+        ExecutionGroupContext<JDBCExecutionUnit> executionGroupContext =
+                
prepareEngine.prepare(executionContexts.iterator().next().getRouteContext(), 
samplingExecutionUnit(dataSourcesToExecutionUnits),
+                        new 
ExecutionGroupReportContext(connectionSession.getProcessId(), 
connectionSession.getDatabaseName(), connectionSession.getGrantee()));
+        for (ExecutionGroup<JDBCExecutionUnit> each : 
executionGroupContext.getInputGroups()) {
+            for (JDBCExecutionUnit unit : each.getInputs()) {
+                prepareBatchedStatement(unit, dataSourcesToExecutionUnits);
             }
         }
         return executeBatchedStatements(executionGroupContext);
     }
     
-    private Collection<ExecutionUnit> samplingExecutionUnit() {
+    private Collection<ExecutionContext> createExecutionContexts() {
+        Collection<ExecutionContext> result = new LinkedList<>();
+        for (QueryContext each : multiSQLQueryContexts) {
+            result.add(createExecutionContext(each));
+        }
+        return result;
+    }
+    
+    private Map<String, List<ExecutionUnit>> 
buildDataSourcesToExecutionUnits(final Collection<ExecutionContext> 
executionContexts) {
+        Map<String, List<ExecutionUnit>> result = new HashMap<>();
+        for (ExecutionContext each : executionContexts) {
+            for (ExecutionUnit executionUnit : each.getExecutionUnits()) {
+                result.computeIfAbsent(executionUnit.getDataSourceName(), 
unused -> new LinkedList<>()).add(executionUnit);
+            }
+        }
+        return result;
+    }
+    
+    private ExecutionContext createExecutionContext(final QueryContext 
queryContext) {
+        RuleMetaData globalRuleMetaData = 
metaDataContexts.getMetaData().getGlobalRuleMetaData();
+        ShardingSphereDatabase currentDatabase = 
metaDataContexts.getMetaData().getDatabase(connectionSession.getDatabaseName());
+        SQLAuditEngine.audit(queryContext.getSqlStatementContext(), 
queryContext.getParameters(), globalRuleMetaData, currentDatabase, null, 
queryContext.getHintValueContext());
+        return kernelProcessor.generateExecutionContext(queryContext, 
currentDatabase, globalRuleMetaData, metaDataContexts.getMetaData().getProps(), 
connectionSession.getConnectionContext());
+    }
+    
+    private Collection<ExecutionUnit> samplingExecutionUnit(final Map<String, 
List<ExecutionUnit>> dataSourcesToExecutionUnits) {
         Collection<ExecutionUnit> result = new LinkedList<>();
         for (List<ExecutionUnit> each : dataSourcesToExecutionUnits.values()) {
             result.add(each.get(0));
@@ -163,10 +176,10 @@ public final class MySQLMultiStatementsHandler implements ProxyBackendHandler {
         return result;
     }
     
-    private void prepareBatchedStatement(final JDBCExecutionUnit each) throws 
SQLException {
-        Statement statement = each.getStorageResource();
-        for (ExecutionUnit eachExecutionUnit : 
dataSourcesToExecutionUnits.get(each.getExecutionUnit().getDataSourceName())) {
-            statement.addBatch(eachExecutionUnit.getSqlUnit().getSql());
+    private void prepareBatchedStatement(final JDBCExecutionUnit 
executionUnit, final Map<String, List<ExecutionUnit>> 
dataSourcesToExecutionUnits) throws SQLException {
+        Statement statement = executionUnit.getStorageResource();
+        for (ExecutionUnit each : 
dataSourcesToExecutionUnits.get(executionUnit.getExecutionUnit().getDataSourceName()))
 {
+            statement.addBatch(each.getSqlUnit().getSql());
         }
     }
     

Reply via email to