suddjian commented on a change in pull request #19055:
URL: https://github.com/apache/superset/pull/19055#discussion_r823044342



##########
File path: superset/sql_parse.py
##########
@@ -458,3 +459,183 @@ def validate_filter_clause(clause: str) -> None:
                 )
     if open_parens > 0:
         raise QueryClauseValidationException("Unclosed parenthesis in filter 
clause")
+
+
+class InsertRLSState(str, Enum):
+    """
+    State machine that scans for WHERE and ON clauses referencing tables.
+    """
+
+    SCANNING = "SCANNING"
+    SEEN_SOURCE = "SEEN_SOURCE"
+    FOUND_TABLE = "FOUND_TABLE"
+
+
+def has_table_query(token_list: TokenList) -> bool:
+    """
+    Return whether a statement has a query reading from a table.
+
+        >>> has_table_query(sqlparse.parse("COUNT(*)")[0])
+        False
+        >>> has_table_query(sqlparse.parse("SELECT * FROM table")[0])
+        True
+
+    Note that queries reading from constant values return false:
+
+        >>> has_table_query(sqlparse.parse("SELECT * FROM (SELECT 1)")[0])
+        False
+
+    """
+    state = InsertRLSState.SCANNING
+    for token in token_list.tokens:
+
+        # Recurse into child token list
+        if isinstance(token, TokenList) and has_table_query(token):
+            return True
+
+        # Found a source keyword (FROM/JOIN)
+        if imt(token, m=[(Keyword, "FROM"), (Keyword, "JOIN")]):
+            state = InsertRLSState.SEEN_SOURCE
+
+        # Found identifier/keyword after FROM/JOIN
+        elif state == InsertRLSState.SEEN_SOURCE and (
+            isinstance(token, sqlparse.sql.Identifier) or token.ttype == 
Keyword
+        ):
+            return True
+
+        # Found nothing, leaving source
+        elif state == InsertRLSState.SEEN_SOURCE and token.ttype != Whitespace:
+            state = InsertRLSState.SCANNING
+
+    return False
+
+
+def add_table_name(rls: TokenList, table: str) -> None:
+    """
+    Modify an RLS expression, ensuring columns are fully qualified.
+    """
+    tokens = rls.tokens[:]
+    while tokens:
+        token = tokens.pop(0)
+
+        if isinstance(token, Identifier) and token.get_parent_name() is None:
+            token.tokens = [
+                Token(Name, table),
+                Token(Punctuation, "."),
+                Token(Name, token.get_name()),
+            ]
+        elif isinstance(token, TokenList):
+            tokens.extend(token.tokens)
+
+
+def matches_table_name(token: Token, table: str) -> bool:
+    """
+    Return whether the token matches the name of a table.
+
+    A table should be represented as an identifier, but due to sqlparse's aggressive list
+    of keywords (spanning multiple dialects) often it gets classified as a keyword.
+    """
+    candidate = token.value
+
+    # match from right to left, splitting on the period, eg, schema.table == table
+    for left, right in zip(candidate.split(".")[::-1], table.split(".")[::-1]):
+        if left != right:
+            return False
+
+    return True
+
+
+def insert_rls(token_list: TokenList, table: str, rls: TokenList) -> TokenList:
+    """
+    Update a statement in place, applying an RLS associated with a given table.
+    """
+    # make sure the identifier has the table name
+    add_table_name(rls, table)
+
+    state = InsertRLSState.SCANNING
+    for token in token_list.tokens:
+
+        # Recurse into child token list
+        if isinstance(token, TokenList):
+            i = token_list.tokens.index(token)
+            token_list.tokens[i] = insert_rls(token, table, rls)
+
+        # Found a source keyword (FROM/JOIN)
+        if imt(token, m=[(Keyword, "FROM"), (Keyword, "JOIN")]):
+            state = InsertRLSState.SEEN_SOURCE
+
+        # Found identifier/keyword after FROM/JOIN, test for table
+        elif state == InsertRLSState.SEEN_SOURCE and (
+            isinstance(token, Identifier) or token.ttype == Keyword
+        ):
+            if matches_table_name(token, table):
+                state = InsertRLSState.FOUND_TABLE
+
+                # found table at the end of the statement; append a WHERE clause
+                if token == token_list[-1]:
+                    token_list.tokens.extend(
+                        [
+                            Token(Whitespace, " "),
+                            Where(
+                                [Token(Keyword, "WHERE"), Token(Whitespace, " 
"), rls]
+                            ),
+                        ]
+                    )
+                    return token_list
+
+        # Found WHERE clause, insert RLS if not present
+        elif state == InsertRLSState.FOUND_TABLE and isinstance(token, Where):
+            if str(rls) not in {str(t) for t in token.tokens}:
+                token.tokens.extend(
+                    [
+                        Token(Whitespace, " "),
+                        Token(Keyword, "AND"),
+                        Token(Whitespace, " "),
+                    ]
+                    + rls.tokens
+                )
+            state = InsertRLSState.SCANNING
+
+        # Found ON clause, insert RLS if not present

Review comment:
       There's no check for the "if not present" part of this.




-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: [email protected]

For queries about this service, please contact Infrastructure at:
[email protected]



---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]

Reply via email to