vinodkc commented on code in PR #39449:
URL: https://github.com/apache/spark/pull/39449#discussion_r1087034137


##########
sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/maskExpressions.scala:
##########
@@ -257,19 +271,272 @@ case class Mask(
       otherChar = newChildren(4))
 }
 
-case class MaskArgument(maskChar: Char, ignore: Boolean)
+// scalastyle:off line.size.limit
+@ExpressionDescription(
+  usage =
+    """_FUNC_(input[, charCount, upperChar, lowerChar, digitChar, otherChar]) 
- masks the first n characters of given string value.
+       The function masks the first n characters of the value with 'X' or 'x', 
and numbers with 'n'.
+       This can be useful for creating copies of tables with sensitive 
information removed.
+       Error behavior: null value as replacement argument will throw 
AnalysisError.
+      """,
+  arguments = """
+    Arguments:
+      * input      - string value to mask. Supported types: STRING, VARCHAR, 
CHAR
+      * charCount  - number of characters to be masked. Default value: 4
+      * upperChar  - character to replace upper-case characters with. Specify 
NULL to retain original character. Default value: 'X'
+      * lowerChar  - character to replace lower-case characters with. Specify 
NULL to retain original character. Default value: 'x'
+      * digitChar  - character to replace digit characters with. Specify NULL 
to retain original character. Default value: 'n'
+      * otherChar  - character to replace all other characters with. Specify 
NULL to retain original character. Default value: NULL
+  """,
+  examples = """
+    Examples:
+      > SELECT _FUNC_('abcd-EFGH-8765-4321');
+        xxxx-EFGH-8765-4321
+      > SELECT _FUNC_('abcd-EFGH-8765-4321', 9);
+        xxxx-XXXX-8765-4321
+      > SELECT _FUNC_('abcd-EFGH-8765-@$#', 14);
+        xxxx-XXXX-nnnn-@$#
+      > SELECT _FUNC_('abcd-EFGH-8765-@$#', 15, 'x', 'X', 'n', 'o');
+        XXXXoxxxxonnnno@$#
+      > SELECT _FUNC_('abcd-EFGH-8765-@$#', 20, 'x', 'X', 'n', 'o');
+        XXXXoxxxxonnnnoooo
+      > SELECT _FUNC_('AbCD123-@$#', 10,'Q', 'q', 'd', 'o');
+        QqQQdddooo#
+      > SELECT _FUNC_('AbCD123-@$#', 10, NULL, 'q', 'd', 'o');
+        AqCDdddooo#
+      > SELECT _FUNC_('AbCD123-@$#', 10, NULL, NULL, 'd', 'o');
+        AbCDdddooo#
+      > SELECT _FUNC_('AbCD123-@$#', 10, NULL, NULL, NULL, 'o');
+        AbCD123ooo#
+      > SELECT _FUNC_(NULL);
+        NULL
+      > SELECT _FUNC_(NULL, 1, NULL, NULL, 'o');
+        NULL
+  """,
+  since = "3.4.0",
+  group = "string_funcs")
+// scalastyle:on line.size.limit
+case class MaskFirstN(
+    input: Expression,
+    charCountExpr: Expression,
+    upperChar: Expression,
+    lowerChar: Expression,
+    digitChar: Expression,
+    otherChar: Expression)
+    extends SeptenaryExpression
+    with Maskable
+    with ExpectsInputTypes
+    with QueryErrorsBase {
+
+  def this(input: Expression) =
+    this(
+      input,
+      Literal(Mask.DEFAULT_CHAR_COUNT),
+      Literal(Mask.MASKED_UPPERCASE),
+      Literal(Mask.MASKED_LOWERCASE),
+      Literal(Mask.MASKED_DIGIT),
+      Literal(Mask.MASKED_IGNORE, StringType))
+
+  def this(input: Expression, charCountExpr: Expression) =
+    this(
+      input,
+      charCountExpr,
+      Literal(Mask.MASKED_UPPERCASE),
+      Literal(Mask.MASKED_LOWERCASE),
+      Literal(Mask.MASKED_DIGIT),
+      Literal(Mask.MASKED_IGNORE, StringType))
+
+  def this(input: Expression, charCountExpr: Expression, upperChar: 
Expression) =
+    this(
+      input,
+      charCountExpr,
+      upperChar,
+      Literal(Mask.MASKED_LOWERCASE),
+      Literal(Mask.MASKED_DIGIT),
+      Literal(Mask.MASKED_IGNORE, StringType))
+
+  def this(
+      input: Expression,
+      charCountExpr: Expression,
+      upperChar: Expression,
+      lowerChar: Expression) =
+    this(
+      input,
+      charCountExpr,
+      upperChar,
+      lowerChar,
+      Literal(Mask.MASKED_DIGIT),
+      Literal(Mask.MASKED_IGNORE, StringType))
+
+  def this(
+      input: Expression,
+      charCountExpr: Expression,
+      upperChar: Expression,
+      lowerChar: Expression,
+      digitChar: Expression) =
+    this(
+      input,
+      charCountExpr,
+      upperChar,
+      lowerChar,
+      digitChar,
+      Literal(Mask.MASKED_IGNORE, StringType))
+
+  @transient
+  private lazy val charCount = {
+    val value = charCountExpr.eval().asInstanceOf[Int]
+    if (value < 0) 0 else value
+  }
+
+  override def checkInputDataTypes(): TypeCheckResult =
+    validateInputDataTypes(
+      super.checkInputDataTypes(),
+      Seq(
+        (upperChar, "upperChar"),
+        (lowerChar, "lowerChar"),
+        (digitChar, "digitChar"),
+        (otherChar, "otherChar")),
+      () =>

Review Comment:
   Moved that method into a template-method pattern.



##########
sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/maskExpressions.scala:
##########
@@ -257,19 +271,272 @@ case class Mask(
       otherChar = newChildren(4))
 }
 
-case class MaskArgument(maskChar: Char, ignore: Boolean)
+// scalastyle:off line.size.limit
+@ExpressionDescription(
+  usage =
+    """_FUNC_(input[, charCount, upperChar, lowerChar, digitChar, otherChar]) 
- masks the first n characters of given string value.
+       The function masks the first n characters of the value with 'X' or 'x', 
and numbers with 'n'.
+       This can be useful for creating copies of tables with sensitive 
information removed.
+       Error behavior: null value as replacement argument will throw 
AnalysisError.
+      """,
+  arguments = """
+    Arguments:
+      * input      - string value to mask. Supported types: STRING, VARCHAR, 
CHAR
+      * charCount  - number of characters to be masked. Default value: 4
+      * upperChar  - character to replace upper-case characters with. Specify 
NULL to retain original character. Default value: 'X'
+      * lowerChar  - character to replace lower-case characters with. Specify 
NULL to retain original character. Default value: 'x'
+      * digitChar  - character to replace digit characters with. Specify NULL 
to retain original character. Default value: 'n'
+      * otherChar  - character to replace all other characters with. Specify 
NULL to retain original character. Default value: NULL
+  """,
+  examples = """
+    Examples:
+      > SELECT _FUNC_('abcd-EFGH-8765-4321');
+        xxxx-EFGH-8765-4321
+      > SELECT _FUNC_('abcd-EFGH-8765-4321', 9);
+        xxxx-XXXX-8765-4321
+      > SELECT _FUNC_('abcd-EFGH-8765-@$#', 14);
+        xxxx-XXXX-nnnn-@$#
+      > SELECT _FUNC_('abcd-EFGH-8765-@$#', 15, 'x', 'X', 'n', 'o');
+        XXXXoxxxxonnnno@$#
+      > SELECT _FUNC_('abcd-EFGH-8765-@$#', 20, 'x', 'X', 'n', 'o');
+        XXXXoxxxxonnnnoooo
+      > SELECT _FUNC_('AbCD123-@$#', 10,'Q', 'q', 'd', 'o');
+        QqQQdddooo#
+      > SELECT _FUNC_('AbCD123-@$#', 10, NULL, 'q', 'd', 'o');
+        AqCDdddooo#
+      > SELECT _FUNC_('AbCD123-@$#', 10, NULL, NULL, 'd', 'o');
+        AbCDdddooo#
+      > SELECT _FUNC_('AbCD123-@$#', 10, NULL, NULL, NULL, 'o');
+        AbCD123ooo#
+      > SELECT _FUNC_(NULL);
+        NULL
+      > SELECT _FUNC_(NULL, 1, NULL, NULL, 'o');
+        NULL
+  """,
+  since = "3.4.0",
+  group = "string_funcs")
+// scalastyle:on line.size.limit
+case class MaskFirstN(
+    input: Expression,
+    charCountExpr: Expression,
+    upperChar: Expression,
+    lowerChar: Expression,
+    digitChar: Expression,
+    otherChar: Expression)
+    extends SeptenaryExpression
+    with Maskable
+    with ExpectsInputTypes
+    with QueryErrorsBase {
+
+  def this(input: Expression) =
+    this(
+      input,
+      Literal(Mask.DEFAULT_CHAR_COUNT),
+      Literal(Mask.MASKED_UPPERCASE),
+      Literal(Mask.MASKED_LOWERCASE),
+      Literal(Mask.MASKED_DIGIT),
+      Literal(Mask.MASKED_IGNORE, StringType))
+
+  def this(input: Expression, charCountExpr: Expression) =
+    this(
+      input,
+      charCountExpr,
+      Literal(Mask.MASKED_UPPERCASE),
+      Literal(Mask.MASKED_LOWERCASE),
+      Literal(Mask.MASKED_DIGIT),
+      Literal(Mask.MASKED_IGNORE, StringType))
+
+  def this(input: Expression, charCountExpr: Expression, upperChar: 
Expression) =
+    this(
+      input,
+      charCountExpr,
+      upperChar,
+      Literal(Mask.MASKED_LOWERCASE),
+      Literal(Mask.MASKED_DIGIT),
+      Literal(Mask.MASKED_IGNORE, StringType))
+
+  def this(
+      input: Expression,
+      charCountExpr: Expression,
+      upperChar: Expression,
+      lowerChar: Expression) =
+    this(
+      input,
+      charCountExpr,
+      upperChar,
+      lowerChar,
+      Literal(Mask.MASKED_DIGIT),
+      Literal(Mask.MASKED_IGNORE, StringType))
+
+  def this(
+      input: Expression,
+      charCountExpr: Expression,
+      upperChar: Expression,
+      lowerChar: Expression,
+      digitChar: Expression) =
+    this(
+      input,
+      charCountExpr,
+      upperChar,
+      lowerChar,
+      digitChar,
+      Literal(Mask.MASKED_IGNORE, StringType))
+
+  @transient
+  private lazy val charCount = {
+    val value = charCountExpr.eval().asInstanceOf[Int]
+    if (value < 0) 0 else value
+  }
+
+  override def checkInputDataTypes(): TypeCheckResult =
+    validateInputDataTypes(
+      super.checkInputDataTypes(),
+      Seq(
+        (upperChar, "upperChar"),
+        (lowerChar, "lowerChar"),
+        (digitChar, "digitChar"),
+        (otherChar, "otherChar")),
+      () =>
+        Seq(if (!charCountExpr.foldable) {
+          Some(
+            DataTypeMismatch(
+              errorSubClass = "NON_FOLDABLE_INPUT",
+              messageParameters = Map(
+                "inputName" -> "charCount",
+                "inputType" -> toSQLType(charCountExpr.dataType),
+                "inputExpr" -> toSQLExpr(charCountExpr))))
+        } else if (charCountExpr.eval() == null) {
+          Some(
+            DataTypeMismatch(
+              errorSubClass = "UNEXPECTED_NULL",
+              messageParameters = Map("exprName" -> "charCount")))
+        } else {
+          None
+        }))
+
+  /**
+   * Expected input types from child expressions. The i-th position in the 
returned seq indicates
+   * the type requirement for the i-th child.
+   *
+   * The possible values at each position are:
+   *   1. a specific data type, e.g. LongType, StringType. 2. a non-leaf 
abstract data type, e.g.
+   *      NumericType, IntegralType, FractionalType.
+   */
+  override def inputTypes: Seq[AbstractDataType] =
+    Seq(StringType, IntegerType, StringType, StringType, StringType, 
StringType)
+
+  override def nullable: Boolean = true
+
+  /**
+   * Default behavior of evaluation according to the default nullability of 
QuinaryExpression. If
+   * subclass of QuinaryExpression override nullable, probably should also 
override this.
+   */
+  override def eval(input: InternalRow): Any = {
+    Mask.mask_first_n(
+      children(0).eval(input),
+      charCount,
+      children(2).eval(input),

Review Comment:
   Done



-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: reviews-unsubscr...@spark.apache.org

For queries about this service, please contact Infrastructure at:
us...@infra.apache.org


---------------------------------------------------------------------
To unsubscribe, e-mail: reviews-unsubscr...@spark.apache.org
For additional commands, e-mail: reviews-h...@spark.apache.org

Reply via email to