This is an automated email from the ASF dual-hosted git repository.

liuxun pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/submarine.git


The following commit(s) were added to refs/heads/master by this push:
     new 81dca71  SUBMARINE-531. Fix data masking for nested fields with functions
81dca71 is described below

commit 81dca71416f8605f765cdb250a6c776f16fdf59e
Author: Kent Yao <yaooq...@hotmail.com>
AuthorDate: Mon Jun 29 21:46:10 2020 +0800

    SUBMARINE-531. Fix data masking for nested fields with functions
    
    ### What is this PR for?
    
    This PR fixes #313, where data masking worked incorrectly for fields nested inside function calls.
    
    ### What type of PR is it?
    Bug Fix
    ### Todos
    * [ ] - Task
    
    ### What is the Jira issue?
    * Open an issue on Jira https://issues.apache.org/jira/browse/SUBMARINE-531
    
    ### How should this be tested?
    
    Added a unit test (`MASK_SHOW_LAST_4 and functions`) that checks masking of a column wrapped in `substr` and `cast`.
    
    ### Screenshots (if appropriate)
    
    ### Questions:
    * Do the license files need updating? No
    * Are there breaking changes for older versions? No
    * Does this need documentation? No
    
    Author: Kent Yao <yaooq...@hotmail.com>
    
    Closes #332 from yaooqinn/SUBMARINE-531 and squashes the following commits:
    
    c8cb492 [Kent Yao] Merge branch 'master' into SUBMARINE-531
    577fce2 [Kent Yao] SUBMARINE-531. Fix data masking for nested fields with functions
    fffed1e [Kent Yao] SUBMARINE-489. DCL Framework: SHOW CURRENT ROLES syntax
---
 .../optimizer/SubmarineDataMaskingExtension.scala  | 30 +++++++++++++++++-----
 .../spark/security/DataMaskingSQLTest.scala        | 22 ++++++++++++++++
 2 files changed, 46 insertions(+), 6 deletions(-)

diff --git a/submarine-security/spark-security/src/main/scala/org/apache/spark/sql/catalyst/optimizer/SubmarineDataMaskingExtension.scala b/submarine-security/spark-security/src/main/scala/org/apache/spark/sql/catalyst/optimizer/SubmarineDataMaskingExtension.scala
index 5dad02a..0132385 100644
--- a/submarine-security/spark-security/src/main/scala/org/apache/spark/sql/catalyst/optimizer/SubmarineDataMaskingExtension.scala
+++ b/submarine-security/spark-security/src/main/scala/org/apache/spark/sql/catalyst/optimizer/SubmarineDataMaskingExtension.scala
@@ -129,7 +129,8 @@ case class SubmarineDataMaskingExtension(spark: SparkSession) extends Rule[Logic
   private def collectTransformers(
       plan: LogicalPlan,
       table: CatalogTable,
-      aliases: mutable.Map[Alias, ExprId]): Map[ExprId, NamedExpression] = {
+      aliases: mutable.Map[Alias, ExprId],
+      outputs: Seq[NamedExpression]): Map[ExprId, NamedExpression] = {
     try {
       val maskEnableResults = plan.output.map { expr =>
         expr -> getAccessResult(table.identifier, expr)
@@ -140,6 +141,18 @@ case class SubmarineDataMaskingExtension(spark: SparkSession) extends Rule[Logic
       }.filter(_._2 != null).toMap
 
       val aliasedMaskers = new mutable.HashMap[ExprId, Alias]()
+
+      for (output <- outputs) {
+        val newOutput = output transformUp {
+          case ar: AttributeReference => formedMaskers.getOrElse(ar.exprId, ar)
+        }
+
+        if (!output.equals(newOutput)) {
+          val newAlias = Alias(newOutput, output.name)()
+          aliasedMaskers.put(output.exprId, newAlias)
+        }
+      }
+
       for ((alias, id) <- aliases if formedMaskers.contains(id)) {
         val originalAlias = formedMaskers(id)
         val newChild = originalAlias.child mapChildren {
@@ -171,10 +184,10 @@ case class SubmarineDataMaskingExtension(spark: SparkSession) extends Rule[Logic
     val aliases = new mutable.HashMap[Alias, ExprId]()
     plan.transformAllExpressions {
       case a: Alias =>
-        a.child match {
+        a.child transformUp {
           case ne: NamedExpression =>
-            aliases.put(a, ne.exprId)
-          case _ =>
+            aliases.getOrElseUpdate(a, ne.exprId)
+            ne
         }
         a
     }
@@ -184,11 +197,16 @@ case class SubmarineDataMaskingExtension(spark: SparkSession) extends Rule[Logic
   private def collectAllTransformers(
       plan: LogicalPlan,
       aliases: mutable.Map[Alias, ExprId]): Map[ExprId, NamedExpression] = {
+    val outputs = plan match {
+      case p: Project => p.projectList
+      case o => o.output
+    }
+
     plan.collectLeaves().flatMap {
       case h: HiveTableRelation =>
-        collectTransformers(h, h.tableMeta, aliases)
+        collectTransformers(h, h.tableMeta, aliases, outputs)
       case l: LogicalRelation if l.catalogTable.isDefined =>
-        collectTransformers(l, l.catalogTable.get, aliases)
+        collectTransformers(l, l.catalogTable.get, aliases, outputs)
       case _ => Seq.empty
     }.toMap
   }
diff --git a/submarine-security/spark-security/src/test/scala/org/apache/submarine/spark/security/DataMaskingSQLTest.scala b/submarine-security/spark-security/src/test/scala/org/apache/submarine/spark/security/DataMaskingSQLTest.scala
index f228845..b33bd5a 100644
--- a/submarine-security/spark-security/src/test/scala/org/apache/submarine/spark/security/DataMaskingSQLTest.scala
+++ b/submarine-security/spark-security/src/test/scala/org/apache/submarine/spark/security/DataMaskingSQLTest.scala
@@ -188,6 +188,28 @@ case class DataMaskingSQLTest() extends FunSuite with BeforeAndAfterAll {
     }
   }
 
+  test("MASK_SHOW_LAST_4 and functions") {
+    val statement =
+      s"""
+         |select
+         | key,
+         | value,
+         | substr(value, 0, 18),
+         | substr(value, 0, 18) as v1,
+         | substr(cast(value as string), 0, 18) as v2
+         | from default.rangertbl5 where value = 'val_277'""".stripMargin
+    withUser("bob") {
+      val df = sql(statement)
+      println(df.queryExecution.optimizedPlan)
+      val row = df.take(1)(0)
+      assert(row.getString(1) === "xxx_277", "value shows last 4 characters")
+      assert(row.getString(2) === "xxx_277", "value shows last 4 characters")
+      assert(row.getString(3) === "xxx_277", "value shows last 4 characters")
+      assert(row.getString(4) === "xxx_277", "value shows last 4 characters")
+
+    }
+  }
+
   test("NO MASKING") {
     val statement = "select * from default.rangertbl6 where value = 'val_277'"
     withUser("bob") {


---------------------------------------------------------------------
To unsubscribe, e-mail: dev-unsubscr...@submarine.apache.org
For additional commands, e-mail: dev-h...@submarine.apache.org

Reply via email to