betodealmeida commented on a change in pull request #19055:
URL: https://github.com/apache/superset/pull/19055#discussion_r823854840
##########
File path: superset/sql_parse.py
##########
@@ -458,3 +459,183 @@ def validate_filter_clause(clause: str) -> None:
)
if open_parens > 0:
raise QueryClauseValidationException("Unclosed parenthesis in filter
clause")
+
+
class InsertRLSState(str, Enum):
    """
    States of the token scanner that looks for WHERE and ON clauses
    referencing tables.

    Each member's value is identical to its name; since the enum also
    derives from ``str``, members compare equal to those plain strings.
    """

    # Default state: not positioned right after a source keyword.
    SCANNING = "SCANNING"

    # A source keyword (FROM/JOIN) has just been consumed.
    SEEN_SOURCE = "SEEN_SOURCE"

    # Presumably: a table reference has been located (not set by the
    # scanning code in has_table_query).
    FOUND_TABLE = "FOUND_TABLE"
+
+
def has_table_query(token_list: TokenList) -> bool:
    """
    Return if a statement has a query reading from a table.

    >>> has_table_query(sqlparse.parse("COUNT(*)")[0])
    False
    >>> has_table_query(sqlparse.parse("SELECT * FROM table")[0])
    True
    >>> has_table_query(sqlparse.parse("SELECT * FROM\\nmy_table")[0])
    True

    Note that queries reading from constant values return false:

    >>> has_table_query(sqlparse.parse("SELECT * FROM (SELECT 1)")[0])
    False

    The scan is deliberately conservative: when in doubt it reports a table,
    so callers relying on it (e.g. to decide whether RLS must be applied)
    do not miss one.
    """
    state = InsertRLSState.SCANNING
    for token in token_list.tokens:
        # Whitespace and comments may sit between FROM/JOIN and the table
        # name and must not reset the state. Newline tokens have ttype
        # ``Whitespace.Newline``, a *subtype* of Whitespace, so an equality
        # check against ``Whitespace`` would wrongly treat them as content;
        # ``in`` matches the whole subtype hierarchy.
        if (
            token.ttype in Whitespace
            or token.ttype in sqlparse.tokens.Comment
            or isinstance(token, sqlparse.sql.Comment)
        ):
            continue

        # Recurse into child token list
        if isinstance(token, TokenList) and has_table_query(token):
            return True

        # Found a source keyword (FROM/JOIN). sqlparse lexes qualified joins
        # such as "LEFT OUTER JOIN" as a single Keyword token, which never
        # equals "JOIN", so also accept any keyword ending in JOIN.
        if imt(token, m=[(Keyword, "FROM"), (Keyword, "JOIN")]) or (
            token.ttype is Keyword and token.normalized.endswith("JOIN")
        ):
            state = InsertRLSState.SEEN_SOURCE

        # Found identifier/keyword after FROM/JOIN
        elif state == InsertRLSState.SEEN_SOURCE and (
            isinstance(token, sqlparse.sql.Identifier) or token.ttype == Keyword
        ):
            return True

        # Something else directly followed FROM/JOIN (e.g. a parenthesized
        # constant subquery): not a table reference, leave the source state.
        elif state == InsertRLSState.SEEN_SOURCE:
            state = InsertRLSState.SCANNING

    return False
+
+
def add_table_name(rls: TokenList, table: str) -> None:
    """
    Modify a RLS expression in place, ensuring columns are fully qualified.

    Every ``Identifier`` in ``rls`` that has no parent name (i.e. is not
    already written as ``something.column``) is prefixed with ``table.``.
    The identifier's original tokens are kept after the prefix, so quoting
    (``"Col"`` stays ``"Col"``) and trailing parts of the identifier, such
    as a ``::type`` cast, are preserved. They are not rebuilt from the
    unquoted column name.

    :param rls: parsed RLS expression; mutated in place
    :param table: table name used to qualify bare column references
    """
    from collections import deque

    # Breadth-first walk over the token tree. A deque gives O(1) pops from
    # the left, whereas ``list.pop(0)`` is O(n) per pop.
    queue = deque(rls.tokens)
    while queue:
        token = queue.popleft()

        if isinstance(token, Identifier) and token.get_parent_name() is None:
            token.tokens = [
                Token(Name, table),
                Token(Punctuation, "."),
                *token.tokens,
            ]
        elif isinstance(token, TokenList):
            queue.extend(token.tokens)
+
+
+def matches_table_name(token: Token, table: str) -> bool:
+ """
+ Return the name of a table.
+
+ A table should be represented as an identifier, but due to sqlparse's
aggressive list
+ of keywords (spanning multiple dialects) often it gets classified as a
keyword.
+ """
+ candidate = token.value
+
+ # match from right to left, splitting on the period, eg, schema.table ==
table
+ for left, right in zip(candidate.split(".")[::-1], table.split(".")[::-1]):
Review comment:
If we scan left to right then `schema.table != table`, since we'd
compare `schema` with `table`. Doing it right to left we compare `table` with
`table`, and stop because one side has no more tokens.
The goal here is to be conservative. If we have an RLS on
`my_schema.my_table` and the query references only `my_table`, without the
explicit schema, we don't know if they represent the same table unless we know
the context in which the query is being executed (i.e., what the default
schema is when the query runs). This method errs on the side of applying the
RLS to tables where it maybe shouldn't apply, instead of not applying it where it should.
--
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.
To unsubscribe, e-mail: [email protected]
For queries about this service, please contact Infrastructure at:
[email protected]
---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]