This is an automated email from the ASF dual-hosted git repository.

zhangliang pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/shardingsphere.git


The following commit(s) were added to refs/heads/master by this push:
     new 30125f2  Implement SQLServerPrivilegeHandler and add unit test (#10166)
30125f2 is described below

commit 30125f2280384df0fcf5bce299d62227c16a5c6d
Author: Raigor <[email protected]>
AuthorDate: Sat Apr 24 22:31:01 2021 +0800

    Implement SQLServerPrivilegeHandler and add unit test (#10166)
    
    * Fixes #10102
    
    * Fixes test cases.
---
 .../dialect/PostgreSQLPrivilegeHandler.java        |  4 +-
 .../builder/dialect/SQLServerPrivilegeHandler.java | 68 +++++++++++++--
 ...est.java => SQLServerPrivilegeHandlerTest.java} | 96 +++++++++++++++++++++-
 3 files changed, 154 insertions(+), 14 deletions(-)

diff --git a/shardingsphere-infra/shardingsphere-infra-authority/shardingsphere-infra-authority-common/src/main/java/org/apache/shardingsphere/authority/provider/natived/builder/dialect/PostgreSQLPrivilegeHandler.java b/shardingsphere-infra/shardingsphere-infra-authority/shardingsphere-infra-authority-common/src/main/java/org/apache/shardingsphere/authority/provider/natived/builder/dialect/PostgreSQLPrivilegeHandler.java
index 5bb6561..09c2993 100644
--- a/shardingsphere-infra/shardingsphere-infra-authority/shardingsphere-infra-authority-common/src/main/java/org/apache/shardingsphere/authority/provider/natived/builder/dialect/PostgreSQLPrivilegeHandler.java
+++ b/shardingsphere-infra/shardingsphere-infra-authority/shardingsphere-infra-authority-common/src/main/java/org/apache/shardingsphere/authority/provider/natived/builder/dialect/PostgreSQLPrivilegeHandler.java
@@ -77,8 +77,8 @@ public final class PostgreSQLPrivilegeHandler implements StoragePrivilegeHandler
         }
     }
     
-    private String getCreateUsersSQL(final ShardingSphereUser users) {
-        return String.format(CREATE_USER_SQL, users.getGrantee(), 
users.getPassword());
+    private String getCreateUsersSQL(final ShardingSphereUser user) {
+        return String.format(CREATE_USER_SQL, user.getGrantee(), 
user.getPassword());
     }
     
     @Override
diff --git a/shardingsphere-infra/shardingsphere-infra-authority/shardingsphere-infra-authority-common/src/main/java/org/apache/shardingsphere/authority/provider/natived/builder/dialect/SQLServerPrivilegeHandler.java b/shardingsphere-infra/shardingsphere-infra-authority/shardingsphere-infra-authority-common/src/main/java/org/apache/shardingsphere/authority/provider/natived/builder/dialect/SQLServerPrivilegeHandler.java
index 596365d..17041bf 100644
--- a/shardingsphere-infra/shardingsphere-infra-authority/shardingsphere-infra-authority-common/src/main/java/org/apache/shardingsphere/authority/provider/natived/builder/dialect/SQLServerPrivilegeHandler.java
+++ b/shardingsphere-infra/shardingsphere-infra-authority/shardingsphere-infra-authority-common/src/main/java/org/apache/shardingsphere/authority/provider/natived/builder/dialect/SQLServerPrivilegeHandler.java
@@ -31,7 +31,6 @@ import java.sql.ResultSet;
 import java.sql.SQLException;
 import java.sql.Statement;
 import java.util.Collection;
-import java.util.Collections;
 import java.util.LinkedHashMap;
 import java.util.LinkedList;
 import java.util.ArrayList;
@@ -46,6 +45,18 @@ import java.util.stream.Collectors;
  */
 public final class SQLServerPrivilegeHandler implements 
StoragePrivilegeHandler {
     
+    private static final String QUESTION_MARK = "?";
+    
+    private static final String LINE_BREAK = "\n";
+    
+    private static final String GO_SQL = "GO";
+    
+    private static final String CREATE_LOGIN_USER_SQL = "CREATE LOGIN %s WITH 
PASSWORD = '%s';";
+    
+    private static final String CREATE_DATABASE_USER_SQL = "CREATE USER %s FOR 
LOGIN %s;";
+    
+    private static final String GRANT_ALL_SQL = "GRANT CONTROL ON DATABASE::%s 
TO %s";
+    
     private static final String GLOBAL_PRIVILEGE_SQL =
             "SELECT pr.name AS GRANTEE, pe.state_desc AS STATE, 
pe.permission_name AS PRIVILEGE_TYPE"
                     + "FROM sys.server_principals AS pr JOIN 
sys.server_permissions AS pe"
@@ -61,16 +72,56 @@ public final class SQLServerPrivilegeHandler implements StoragePrivilegeHandler
             "SELECT GRANTOR, GRANTEE, TABLE_CATALOG, TABLE_SCHEMA, TABLE_NAME, 
PRIVILEGE_TYPE, IS_GRANTABLE from INFORMATION_SCHEMA.TABLE_PRIVILEGES WHERE 
GRANTEE IN (%s)";
     
     @Override
-    public Collection<ShardingSphereUser> diff(final 
Collection<ShardingSphereUser> users, final DataSource dataSource) {
-        return Collections.emptyList();
+    public Collection<ShardingSphereUser> diff(final 
Collection<ShardingSphereUser> users, final DataSource dataSource) throws 
SQLException {
+        Collection<Grantee> grantees = new LinkedList<>();
+        try (Connection connection = dataSource.getConnection()) {
+            Statement statement = connection.createStatement();
+            try (ResultSet resultSet = 
statement.executeQuery(getGlobalPrivilegesSQL(users))) {
+                while (resultSet.next()) {
+                    grantees.add(new Grantee(resultSet.getString("GRANTEE"), 
""));
+                }
+            }
+        }
+        return users.stream().filter(each -> 
!grantees.contains(each.getGrantee())).collect(Collectors.toList());
     }
     
     @Override
-    public void create(final Collection<ShardingSphereUser> users, final 
DataSource dataSource) {
+    public void create(final Collection<ShardingSphereUser> users, final 
DataSource dataSource) throws SQLException {
+        try (Connection connection = dataSource.getConnection(); Statement 
statement = connection.createStatement()) {
+            for (ShardingSphereUser each : users) {
+                statement.execute(getCreateUsersSQL(each));
+            }
+        }
+    }
+    
+    private String getCreateUsersSQL(final ShardingSphereUser user) {
+        StringBuilder result = new StringBuilder();
+        result.append(String.format(CREATE_LOGIN_USER_SQL, 
user.getGrantee().getUsername(), user.getPassword())).append(LINE_BREAK);
+        result.append(GO_SQL).append(LINE_BREAK);
+        result.append(String.format(CREATE_DATABASE_USER_SQL, 
user.getGrantee().getUsername(), 
user.getGrantee().getUsername())).append(LINE_BREAK);
+        result.append(GO_SQL);
+        return result.toString();
     }
     
     @Override
-    public void grantAll(final Collection<ShardingSphereUser> users, final 
DataSource dataSource) {
+    public void grantAll(final Collection<ShardingSphereUser> users, final 
DataSource dataSource) throws SQLException {
+        try (Connection connection = dataSource.getConnection(); Statement 
statement = connection.createStatement()) {
+            String databaseName = 
getDatabaseName(connection.getMetaData().getURL());
+            for (ShardingSphereUser each : users) {
+                statement.execute(getGrantAllSQL(databaseName, each));
+            }
+        }
+    }
+    
+    private String getDatabaseName(final String url) {
+        if (url.contains(QUESTION_MARK)) {
+            return url.substring(url.indexOf("DatabaseName=") + 1, 
url.indexOf("?"));
+        }
+        return url.substring(url.indexOf("DatabaseName=") + 1);
+    }
+    
+    private String getGrantAllSQL(final String databaseName, final 
ShardingSphereUser user) {
+        return String.format(GRANT_ALL_SQL, databaseName, 
user.getGrantee().getUsername());
     }
     
     @Override
@@ -172,13 +223,12 @@ public final class SQLServerPrivilegeHandler implements StoragePrivilegeHandler
     }
     
     private String getGlobalPrivilegesSQL(final Collection<ShardingSphereUser> 
users) {
-        String userHostTuples = users.stream().map(each -> 
String.format("'%s'", 
each.getGrantee().getUsername())).collect(Collectors.joining(","));
-        return String.format(GLOBAL_PRIVILEGE_SQL, userHostTuples);
+        String userList = users.stream().map(each -> String.format("'%s'", 
each.getGrantee().getUsername())).collect(Collectors.joining(", "));
+        return String.format(GLOBAL_PRIVILEGE_SQL, userList);
     }
     
     private String getSchemaPrivilegesSQL(final Collection<ShardingSphereUser> 
users) {
-        String userList = users.stream().map(each -> String.format("'%s'", 
each.getGrantee().getUsername()))
-                .collect(Collectors.joining(","));
+        String userList = users.stream().map(each -> String.format("'%s'", 
each.getGrantee().getUsername())).collect(Collectors.joining(", "));
         return String.format(SCHEMA_PRIVILEGE_SQL, userList);
     }
     
diff --git a/shardingsphere-infra/shardingsphere-infra-authority/shardingsphere-infra-authority-common/src/test/java/org/apache/shardingsphere/authority/provider/natived/builder/dialect/SQLServerPrivilegeLoaderTest.java b/shardingsphere-infra/shardingsphere-infra-authority/shardingsphere-infra-authority-common/src/test/java/org/apache/shardingsphere/authority/provider/natived/builder/dialect/SQLServerPrivilegeHandlerTest.java
similarity index 57%
rename from shardingsphere-infra/shardingsphere-infra-authority/shardingsphere-infra-authority-common/src/test/java/org/apache/shardingsphere/authority/provider/natived/builder/dialect/SQLServerPrivilegeLoaderTest.java
rename to shardingsphere-infra/shardingsphere-infra-authority/shardingsphere-infra-authority-common/src/test/java/org/apache/shardingsphere/authority/provider/natived/builder/dialect/SQLServerPrivilegeHandlerTest.java
index 98a3d1d..e23fab2 100644
--- a/shardingsphere-infra/shardingsphere-infra-authority/shardingsphere-infra-authority-common/src/test/java/org/apache/shardingsphere/authority/provider/natived/builder/dialect/SQLServerPrivilegeLoaderTest.java
+++ b/shardingsphere-infra/shardingsphere-infra-authority/shardingsphere-infra-authority-common/src/test/java/org/apache/shardingsphere/authority/provider/natived/builder/dialect/SQLServerPrivilegeHandlerTest.java
@@ -28,8 +28,11 @@ import org.junit.BeforeClass;
 import org.junit.Test;
 
 import javax.sql.DataSource;
+import java.sql.Connection;
 import java.sql.ResultSet;
 import java.sql.SQLException;
+import java.sql.Statement;
+import java.sql.DatabaseMetaData;
 import java.util.Arrays;
 import java.util.Collection;
 import java.util.LinkedList;
@@ -43,8 +46,9 @@ import static org.junit.Assert.assertThat;
 import static org.mockito.Mockito.RETURNS_DEEP_STUBS;
 import static org.mockito.Mockito.mock;
 import static org.mockito.Mockito.when;
+import static org.mockito.Mockito.verify;
 
-public final class SQLServerPrivilegeLoaderTest {
+public final class SQLServerPrivilegeHandlerTest {
     
     @BeforeClass
     public static void setUp() {
@@ -52,6 +56,30 @@ public final class SQLServerPrivilegeLoaderTest {
     }
     
     @Test
+    public void assertDiff() throws SQLException {
+        Collection<ShardingSphereUser> newUsers = createUsers();
+        newUsers.add(new ShardingSphereUser("testUser2", "", ""));
+        DataSource dataSource = mockDataSourceForUsers(newUsers);
+        
assertDiffUsers(TypedSPIRegistry.getRegisteredService(StoragePrivilegeHandler.class,
 "SQLServer", new Properties()).diff(newUsers, dataSource));
+    }
+    
+    @Test
+    public void assertCreate() throws SQLException {
+        Collection<ShardingSphereUser> users = createUsers();
+        DataSource dataSource = mockDataSourceForUsers(users);
+        TypedSPIRegistry.getRegisteredService(StoragePrivilegeHandler.class, 
"SQLServer", new Properties()).create(users, dataSource);
+        assertCreateUsers(users, dataSource.getConnection().createStatement());
+    }
+    
+    @Test
+    public void assertGrantAll() throws SQLException {
+        Collection<ShardingSphereUser> users = createUsers();
+        DataSource dataSource = mockDataSourceForUsers(users);
+        TypedSPIRegistry.getRegisteredService(StoragePrivilegeHandler.class, 
"SQLServer", new Properties()).grantAll(users, dataSource);
+        assertGrantUsersAll(users, 
dataSource.getConnection().createStatement());
+    }
+    
+    @Test
     public void assertLoad() throws SQLException {
         Collection<ShardingSphereUser> users = createUsers();
         DataSource dataSource = mockDataSource(users);
@@ -59,7 +87,7 @@ public final class SQLServerPrivilegeLoaderTest {
     }
     
     private void assertPrivileges(final Map<ShardingSphereUser, 
NativePrivileges> actual) {
-        assertThat(actual.size(), is(1));
+        assertThat(actual.size(), is(2));
         ShardingSphereUser dbo = new ShardingSphereUser("dbo", "", "");
         
assertThat(actual.get(dbo).getAdministrativePrivileges().getPrivileges().size(),
 is(2));
         Collection<PrivilegeType> expectedAdminPrivileges = new 
CopyOnWriteArraySet<>(Arrays.asList(PrivilegeType.CONNECT, 
PrivilegeType.SHUTDOWN));
@@ -69,11 +97,17 @@ public final class SQLServerPrivilegeLoaderTest {
                 PrivilegeType.DELETE));
         SchemaPrivileges schemaPrivileges = 
actual.get(dbo).getDatabasePrivileges().getSpecificPrivileges().get("db0");
         
assertThat(schemaPrivileges.getSpecificPrivileges().get("t_order").hasPrivileges(expectedSpecificPrivilege),
 is(true));
+        
+        ShardingSphereUser testUser = new ShardingSphereUser("testUser", "", 
"");
+        
assertThat(actual.get(testUser).getAdministrativePrivileges().getPrivileges().size(),
 is(0));
+        
assertThat(actual.get(testUser).getDatabasePrivileges().getGlobalPrivileges().size(),
 is(0));
+        
assertThat(actual.get(testUser).getDatabasePrivileges().getSpecificPrivileges().size(),
 is(0));
     }
     
     private Collection<ShardingSphereUser> createUsers() {
         LinkedList<ShardingSphereUser> result = new LinkedList<>();
-        result.add(new ShardingSphereUser("dbo", "", ""));
+        result.add(new ShardingSphereUser("dbo", "password", ""));
+        result.add(new ShardingSphereUser("testUser", "password", ""));
         return result;
     }
     
@@ -97,6 +131,32 @@ public final class SQLServerPrivilegeLoaderTest {
         return result;
     }
     
+    private DataSource mockDataSourceForUsers(final 
Collection<ShardingSphereUser> users) throws SQLException {
+        ResultSet usersResultSet = mockUsersResultSet();
+        DataSource result = mock(DataSource.class, RETURNS_DEEP_STUBS);
+        Statement statement = mock(Statement.class);
+        Connection connection = mock(Connection.class);
+        DatabaseMetaData databaseMetaData = mock(DatabaseMetaData.class);
+        String diffUsersSQL = "SELECT pr.name AS GRANTEE, pe.state_desc AS 
STATE, pe.permission_name AS PRIVILEGE_TYPE"
+                + "FROM sys.server_principals AS pr JOIN 
sys.server_permissions AS pe"
+                + "ON pe.grantee_principal_id = pr.principal_id WHERE pr.name 
IN (%s) GROUP BY pr.name, pe.state_desc, pe.permission_name";
+        String userList = users.stream().map(item -> String.format("'%s'", 
item.getGrantee().getUsername())).collect(Collectors.joining(", "));
+        when(statement.executeQuery(String.format(diffUsersSQL, 
userList))).thenReturn(usersResultSet);
+        when(connection.createStatement()).thenReturn(statement);
+        when(result.getConnection()).thenReturn(connection);
+        when(statement.getConnection()).thenReturn(connection);
+        when(connection.getMetaData()).thenReturn(databaseMetaData);
+        
when(databaseMetaData.getURL()).thenReturn("jdbc:sqlserver://127.0.0.1;DatabaseName=ds_0");
+        return result;
+    }
+    
+    private ResultSet mockUsersResultSet() throws SQLException {
+        ResultSet result = mock(ResultSet.class);
+        when(result.next()).thenReturn(true, true, false);
+        when(result.getString("GRANTEE")).thenReturn("dbo", "testUser");
+        return result;
+    }
+    
     private ResultSet mockGlobalPrivilegeResultSet() throws SQLException {
         ResultSet result = mock(ResultSet.class, RETURNS_DEEP_STUBS);
         when(result.next()).thenReturn(true, true, false);
@@ -126,4 +186,34 @@ public final class SQLServerPrivilegeLoaderTest {
         when(result.getString("DB")).thenReturn("t_order");
         return result;
     }
+    
+    private void assertDiffUsers(final Collection<ShardingSphereUser> users) {
+        assertThat(users.size(), is(1));
+        assertThat(users.iterator().next().getGrantee().getUsername(), 
is("testUser2"));
+    }
+    
+    private void assertCreateUsers(final Collection<ShardingSphereUser> users, 
final Statement statement) throws SQLException {
+        for (ShardingSphereUser each : users) {
+            StringBuilder result = new StringBuilder();
+            result.append(String.format("CREATE LOGIN %s WITH PASSWORD = 
'%s';", each.getGrantee().getUsername(), each.getPassword())).append("\n");
+            result.append("GO").append("\n");
+            result.append(String.format("CREATE USER %s FOR LOGIN %s;\n", 
each.getGrantee().getUsername(), each.getGrantee().getUsername()));
+            result.append("GO");
+            verify(statement).execute(result.toString());
+        }
+    }
+    
+    private void assertGrantUsersAll(final Collection<ShardingSphereUser> 
users, final Statement statement) throws SQLException {
+        String databaseName = 
getDatabaseName(statement.getConnection().getMetaData().getURL());
+        for (ShardingSphereUser each : users) {
+            verify(statement).execute(String.format("GRANT CONTROL ON 
DATABASE::%s TO %s", databaseName, each.getGrantee().getUsername()));
+        }
+    }
+    
+    private String getDatabaseName(final String url) {
+        if (url.contains("?")) {
+            return url.substring(url.indexOf("DatabaseName=") + 1, 
url.indexOf("?"));
+        }
+        return url.substring(url.indexOf("DatabaseName=") + 1);
+    }
 }

Reply via email to