betodealmeida commented on a change in pull request #19055:
URL: https://github.com/apache/superset/pull/19055#discussion_r824762461



##########
File path: superset/sql_parse.py
##########
@@ -458,3 +459,183 @@ def validate_filter_clause(clause: str) -> None:
                 )
     if open_parens > 0:
         raise QueryClauseValidationException("Unclosed parenthesis in filter clause")
+
+
+class InsertRLSState(str, Enum):
+    """
+    State machine that scans for WHERE and ON clauses referencing tables.
+    """
+
+    SCANNING = "SCANNING"
+    SEEN_SOURCE = "SEEN_SOURCE"
+    FOUND_TABLE = "FOUND_TABLE"
+
+
+def has_table_query(token_list: TokenList) -> bool:
+    """
+    Return if a stament has a query reading from a table.
+
+        >>> has_table_query(sqlparse.parse("COUNT(*)")[0])
+        False
+        >>> has_table_query(sqlparse.parse("SELECT * FROM table")[0])
+        True
+
+    Note that queries reading from constant values return false:
+
+        >>> has_table_query(sqlparse.parse("SELECT * FROM (SELECT 1)")[0])
+        False
+
+    """
+    state = InsertRLSState.SCANNING
+    for token in token_list.tokens:
+
+        # # Recurse into child token list
+        if isinstance(token, TokenList) and has_table_query(token):
+            return True
+
+        # Found a source keyword (FROM/JOIN)
+        if imt(token, m=[(Keyword, "FROM"), (Keyword, "JOIN")]):
+            state = InsertRLSState.SEEN_SOURCE
+
+        # Found identifier/keyword after FROM/JOIN
+        elif state == InsertRLSState.SEEN_SOURCE and (
+            isinstance(token, sqlparse.sql.Identifier) or token.ttype == Keyword
+        ):
+            return True
+
+        # Found nothing, leaving source
+        elif state == InsertRLSState.SEEN_SOURCE and token.ttype != Whitespace:
+            state = InsertRLSState.SCANNING
+
+    return False
+
+
+def add_table_name(rls: TokenList, table: str) -> None:
+    """
+    Modify a RLS expression ensuring columns are fully qualified.
+    """
+    tokens = rls.tokens[:]
+    while tokens:
+        token = tokens.pop(0)
+
+        if isinstance(token, Identifier) and token.get_parent_name() is None:
+            token.tokens = [
+                Token(Name, table),
+                Token(Punctuation, "."),
+                Token(Name, token.get_name()),
+            ]
+        elif isinstance(token, TokenList):
+            tokens.extend(token.tokens)
+
+
+def matches_table_name(token: Token, table: str) -> bool:
+    """
+    Return the name of a table.
+
+    A table should be represented as an identifier, but due to sqlparse's aggressive list
+    of keywords (spanning multiple dialects) often it gets classified as a keyword.
+    """
+    candidate = token.value
+
+    # match from right to left, splitting on the period, eg, schema.table == table
+    for left, right in zip(candidate.split(".")[::-1], table.split(".")[::-1]):

Review comment:
       Ah, good point, @villebro! I'll add more tests.




-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: [email protected]

For queries about this service, please contact Infrastructure at:
[email protected]



---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]

Reply via email to