villebro commented on a change in pull request #19055:
URL: https://github.com/apache/superset/pull/19055#discussion_r824720026



##########
File path: tests/unit_tests/sql_parse_tests.py
##########
@@ -1189,3 +1193,225 @@ def test_sqlparse_issue_652():
     stmt = sqlparse.parse(r"foo = '\' AND bar = 'baz'")[0]
     assert len(stmt.tokens) == 5
     assert str(stmt.tokens[0]) == "foo = '\\'"
+
+
[email protected](
+    "sql,expected",
+    [
+        ("SELECT * FROM table", True),
+        ("SELECT a FROM (SELECT 1 AS a) JOIN (SELECT * FROM table)", True),
+        ("(SELECT COUNT(DISTINCT name) AS foo FROM    birth_names)", True),
+        ("COUNT(*)", False),
+        ("SELECT a FROM (SELECT 1 AS a)", False),
+        ("SELECT a FROM (SELECT 1 AS a) JOIN table", True),
+        ("SELECT * FROM (SELECT 1 AS foo, 2 AS bar) ORDER BY foo ASC, bar", 
False),
+        ("SELECT * FROM other_table", True),
+    ],
+)

Review comment:
       Could we add a test case for that `EXTRACT` expression above:
`("extract(HOUR from from_unixtime(hour_ts)", False)`? It caused a downgrade
of `sqlparse` a while back (#10165). I tested it, and it currently fails; I'm
not sure whether that's a problem or something we can live with.

##########
File path: superset/sql_parse.py
##########
@@ -458,3 +459,199 @@ def validate_filter_clause(clause: str) -> None:
                 )
     if open_parens > 0:
         raise QueryClauseValidationException("Unclosed parenthesis in filter 
clause")
+
+
+class InsertRLSState(str, Enum):
+    """
+    State machine that scans for WHERE and ON clauses referencing tables.
+    """
+
+    SCANNING = "SCANNING"
+    SEEN_SOURCE = "SEEN_SOURCE"
+    FOUND_TABLE = "FOUND_TABLE"
+
+
+def has_table_query(token_list: TokenList) -> bool:
+    """
+    Return if a stament has a query reading from a table.
+
+        >>> has_table_query(sqlparse.parse("COUNT(*)")[0])
+        False
+        >>> has_table_query(sqlparse.parse("SELECT * FROM table")[0])
+        True
+
+    Note that queries reading from constant values return false:
+
+        >>> has_table_query(sqlparse.parse("SELECT * FROM (SELECT 1)")[0])
+        False
+
+    """
+    state = InsertRLSState.SCANNING
+    for token in token_list.tokens:
+
+        # # Recurse into child token list
+        if isinstance(token, TokenList) and has_table_query(token):
+            return True
+
+        # Found a source keyword (FROM/JOIN)
+        if imt(token, m=[(Keyword, "FROM"), (Keyword, "JOIN")]):
+            state = InsertRLSState.SEEN_SOURCE
+
+        # Found identifier/keyword after FROM/JOIN
+        elif state == InsertRLSState.SEEN_SOURCE and (
+            isinstance(token, sqlparse.sql.Identifier) or token.ttype == 
Keyword
+        ):
+            return True
+
+        # Found nothing, leaving source
+        elif state == InsertRLSState.SEEN_SOURCE and token.ttype != Whitespace:
+            state = InsertRLSState.SCANNING
+
+    return False
+
+
+def add_table_name(rls: TokenList, table: str) -> None:
+    """
+    Modify a RLS expression ensuring columns are fully qualified.
+    """
+    tokens = rls.tokens[:]
+    while tokens:
+        token = tokens.pop(0)
+
+        if isinstance(token, Identifier) and token.get_parent_name() is None:
+            token.tokens = [
+                Token(Name, table),
+                Token(Punctuation, "."),
+                Token(Name, token.get_name()),
+            ]
+        elif isinstance(token, TokenList):
+            tokens.extend(token.tokens)
+
+
+def matches_table_name(token: Token, table: str) -> bool:
+    """
+    Returns if the token represents a reference to the table.
+
+    Tables can be fully qualified with periods.
+
+    Note that in theory a table should be represented as an identifier, but 
due to
+    sqlparse's aggressive list of keywords (spanning multiple dialects) often 
it gets
+    classified as a keyword.
+    """
+    candidate = token.value
+
+    # match from right to left, splitting on the period, eg, schema.table == 
table
+    for left, right in zip(candidate.split(".")[::-1], table.split(".")[::-1]):
+        if left != right:
+            return False
+
+    return True
+
+
+def insert_rls(token_list: TokenList, table: str, rls: TokenList) -> TokenList:
+    """
+    Update a statement inpalce applying an RLS associated with a given table.

Review comment:
       Typo ("inpalce" → "inplace"):
   ```suggestion
       Update a statement inplace applying an RLS associated with a given table.
   ```

##########
File path: superset/sql_parse.py
##########
@@ -458,3 +459,183 @@ def validate_filter_clause(clause: str) -> None:
                 )
     if open_parens > 0:
         raise QueryClauseValidationException("Unclosed parenthesis in filter 
clause")
+
+
+class InsertRLSState(str, Enum):
+    """
+    State machine that scans for WHERE and ON clauses referencing tables.
+    """
+
+    SCANNING = "SCANNING"
+    SEEN_SOURCE = "SEEN_SOURCE"
+    FOUND_TABLE = "FOUND_TABLE"
+
+
+def has_table_query(token_list: TokenList) -> bool:
+    """
+    Return if a stament has a query reading from a table.
+
+        >>> has_table_query(sqlparse.parse("COUNT(*)")[0])
+        False
+        >>> has_table_query(sqlparse.parse("SELECT * FROM table")[0])
+        True
+
+    Note that queries reading from constant values return false:
+
+        >>> has_table_query(sqlparse.parse("SELECT * FROM (SELECT 1)")[0])
+        False
+
+    """
+    state = InsertRLSState.SCANNING
+    for token in token_list.tokens:
+
+        # # Recurse into child token list
+        if isinstance(token, TokenList) and has_table_query(token):
+            return True
+
+        # Found a source keyword (FROM/JOIN)
+        if imt(token, m=[(Keyword, "FROM"), (Keyword, "JOIN")]):
+            state = InsertRLSState.SEEN_SOURCE
+
+        # Found identifier/keyword after FROM/JOIN
+        elif state == InsertRLSState.SEEN_SOURCE and (
+            isinstance(token, sqlparse.sql.Identifier) or token.ttype == 
Keyword
+        ):
+            return True
+
+        # Found nothing, leaving source
+        elif state == InsertRLSState.SEEN_SOURCE and token.ttype != Whitespace:
+            state = InsertRLSState.SCANNING
+
+    return False
+
+
+def add_table_name(rls: TokenList, table: str) -> None:
+    """
+    Modify a RLS expression ensuring columns are fully qualified.
+    """
+    tokens = rls.tokens[:]
+    while tokens:
+        token = tokens.pop(0)
+
+        if isinstance(token, Identifier) and token.get_parent_name() is None:
+            token.tokens = [
+                Token(Name, table),
+                Token(Punctuation, "."),
+                Token(Name, token.get_name()),
+            ]
+        elif isinstance(token, TokenList):
+            tokens.extend(token.tokens)
+
+
+def matches_table_name(token: Token, table: str) -> bool:
+    """
+    Return the name of a table.
+
+    A table should be represented as an identifier, but due to sqlparse's 
aggressive list
+    of keywords (spanning multiple dialects) often it gets classified as a 
keyword.
+    """
+    candidate = token.value
+
+    # match from right to left, splitting on the period, eg, schema.table == 
table
+    for left, right in zip(candidate.split(".")[::-1], table.split(".")[::-1]):

Review comment:
       Edge case: more and more databases nowadays reference files as tables,
and those file names often include periods. I wonder if we need to consider
quoted identifiers like `schema."table.parquet"`. There the rightmost element
should be `table.parquet`, but naively splitting on "." yields `parquet"`.




-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: [email protected]

For queries about this service, please contact Infrastructure at:
[email protected]



---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]

Reply via email to