villebro commented on a change in pull request #10946:
URL: https://github.com/apache/incubator-superset/pull/10946#discussion_r492004099



##########
File path: tests/security_tests.py
##########
@@ -1009,88 +1010,143 @@ class TestRowLevelSecurity(SupersetTestCase):
     """
 
     rls_entry = None
+    query_obj = dict(
+        groupby=[],
+        metrics=[],
+        filter=[],
+        is_timeseries=False,
+        columns=["value"],
+        granularity=None,
+        from_dttm=None,
+        to_dttm=None,
+        extras={},
+    )
+    GAMMA_FILTER_REGEX = re.compile(r"'[A,B,Q]%'")
+    BASE_FILTER_REGEX = re.compile(r"'boy'")
 
     def setUp(self):
         session = db.session
 
-        # Create the RowLevelSecurityFilter
-        self.rls_entry = RowLevelSecurityFilter()
-        self.rls_entry.tables.extend(
+        # Create regular RowLevelSecurityFilter (energy_usage, unicode_test)
+        self.rls_entry1 = RowLevelSecurityFilter()
+        self.rls_entry1.tables.extend(
             session.query(SqlaTable)
             .filter(SqlaTable.table_name.in_(["energy_usage", "unicode_test"]))
             .all()
         )
-        self.rls_entry.clause = "value > {{ cache_key_wrapper(1) }}"
-        self.rls_entry.roles.append(
-            security_manager.find_role("Gamma")
-        )  # db.session.query(Role).filter_by(name="Gamma").first())
-        self.rls_entry.roles.append(security_manager.find_role("Alpha"))
-        db.session.add(self.rls_entry)
+        self.rls_entry1.filter_type = "Regular"
+        self.rls_entry1.clause = "value > {{ cache_key_wrapper(1) }}"
+        self.rls_entry1.group_key = None
+        self.rls_entry1.roles.append(security_manager.find_role("Gamma"))
+        self.rls_entry1.roles.append(security_manager.find_role("Alpha"))
+        db.session.add(self.rls_entry1)
+
+        # Create regular RowLevelSecurityFilter (birth_names name starts with A or B)
+        self.rls_entry2 = RowLevelSecurityFilter()
+        self.rls_entry2.tables.extend(
+            session.query(SqlaTable)
+            .filter(SqlaTable.table_name.in_(["birth_names"]))
+            .all()
+        )
+        self.rls_entry2.filter_type = "Regular"
+        self.rls_entry2.clause = "name like 'A%' or name like 'B%'"
+        self.rls_entry2.group_key = "name"
+        self.rls_entry2.roles.append(security_manager.find_role("Gamma"))
+        db.session.add(self.rls_entry2)
+
+        # Create Regular RowLevelSecurityFilter (birth_names name starts with Q)
+        self.rls_entry3 = RowLevelSecurityFilter()
+        self.rls_entry3.tables.extend(
+            session.query(SqlaTable)
+            .filter(SqlaTable.table_name.in_(["birth_names"]))
+            .all()
+        )
+        self.rls_entry3.filter_type = "Regular"
+        self.rls_entry3.clause = "name like 'Q%'"
+        self.rls_entry3.group_key = "name"
+        self.rls_entry3.roles.append(security_manager.find_role("Gamma"))
+        db.session.add(self.rls_entry3)
+
+        # Create Base RowLevelSecurityFilter (birth_names boys)
+        self.rls_entry4 = RowLevelSecurityFilter()
+        self.rls_entry4.tables.extend(
+            session.query(SqlaTable)
+            .filter(SqlaTable.table_name.in_(["birth_names"]))
+            .all()
+        )
+        self.rls_entry4.filter_type = "Base"
+        self.rls_entry4.clause = "gender = 'boy'"
+        self.rls_entry4.group_key = "gender"
+        self.rls_entry4.roles.append(security_manager.find_role("Admin"))
+        db.session.add(self.rls_entry4)
 
         db.session.commit()
 
     def tearDown(self):
         session = db.session
-        session.delete(self.rls_entry)
+        session.delete(self.rls_entry1)
+        session.delete(self.rls_entry2)
+        session.delete(self.rls_entry3)
+        session.delete(self.rls_entry4)
         session.commit()
 
-    # Do another test to make sure it doesn't alter another query
-    def test_rls_filter_alters_query(self):
-        g.user = self.get_user(
-            username="alpha"
-        )  # self.login() doesn't actually set the user
+    def test_rls_filter_alters_energy_query(self):
+        g.user = self.get_user(username="alpha")
         tbl = self.get_table_by_name("energy_usage")
-        query_obj = dict(
-            groupby=[],
-            metrics=[],
-            filter=[],
-            is_timeseries=False,
-            columns=["value"],
-            granularity=None,
-            from_dttm=None,
-            to_dttm=None,
-            extras={},
-        )
-        sql = tbl.get_query_str(query_obj)
-        assert tbl.get_extra_cache_keys(query_obj) == [1]
+        sql = tbl.get_query_str(self.query_obj)
+        assert tbl.get_extra_cache_keys(self.query_obj) == [1]
         assert "value > 1" in sql
 
-    def test_rls_filter_doesnt_alter_query(self):
+    def test_rls_filter_doesnt_alter_energy_query(self):
         g.user = self.get_user(
             username="admin"
         )  # self.login() doesn't actually set the user
         tbl = self.get_table_by_name("energy_usage")
-        query_obj = dict(
-            groupby=[],
-            metrics=[],
-            filter=[],
-            is_timeseries=False,
-            columns=["value"],
-            granularity=None,
-            from_dttm=None,
-            to_dttm=None,
-            extras={},
-        )
-        sql = tbl.get_query_str(query_obj)
-        assert tbl.get_extra_cache_keys(query_obj) == []
+        sql = tbl.get_query_str(self.query_obj)
+        assert tbl.get_extra_cache_keys(self.query_obj) == []
         assert "value > 1" not in sql
 
     def test_multiple_table_filter_alters_another_tables_query(self):
         g.user = self.get_user(
             username="alpha"
         )  # self.login() doesn't actually set the user
         tbl = self.get_table_by_name("unicode_test")
-        query_obj = dict(
-            groupby=[],
-            metrics=[],
-            filter=[],
-            is_timeseries=False,
-            columns=["value"],
-            granularity=None,
-            from_dttm=None,
-            to_dttm=None,
-            extras={},
-        )
-        sql = tbl.get_query_str(query_obj)
-        assert tbl.get_extra_cache_keys(query_obj) == [1]
+        sql = tbl.get_query_str(self.query_obj)
+        assert tbl.get_extra_cache_keys(self.query_obj) == [1]
         assert "value > 1" in sql
+
+    def test_rls_filter_alters_gamma_birth_names_query(self):
+        g.user = self.get_user(username="gamma")
+        tbl = self.get_table_by_name("birth_names")
+        sql = tbl.get_query_str(self.query_obj)
+
+        # establish that both regular and base filters are present
+        assert self.GAMMA_FILTER_REGEX.search(sql)
+        assert self.BASE_FILTER_REGEX.search(sql)
+
+        # establish that they are grouped together correctly with ANDs, ORs
+        # and parens in the correct place (only look for unique bits in the
+        # filters to make the regex simpler)
+        assert re.search(

Review comment:
       I changed the complex regex to a full WHERE-clause assertion; for the others, I'm just checking that the exact clause is in the query.




----------------------------------------------------------------
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

For queries about this service, please contact Infrastructure at:
[email protected]



---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]

Reply via email to