This is an automated email from the ASF dual-hosted git repository.

sarutak pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git


The following commit(s) were added to refs/heads/master by this push:
     new 8f6e439  [SPARK-37586][SQL] Add the `mode` and `padding` args to 
`aes_encrypt()`/`aes_decrypt()`
8f6e439 is described below

commit 8f6e439068281633acefb895f8c4bd9203868c24
Author: Max Gekk <max.g...@gmail.com>
AuthorDate: Thu Dec 9 14:36:47 2021 +0900

    [SPARK-37586][SQL] Add the `mode` and `padding` args to 
`aes_encrypt()`/`aes_decrypt()`
    
    ### What changes were proposed in this pull request?
    In this PR, I propose adding new optional arguments, with default values, to the `aes_encrypt()` and `aes_decrypt()` functions:
    1. `mode` - specifies which block cipher mode should be used to encrypt/decrypt messages. Currently, the only valid value is `ECB`.
    2. `padding` - specifies how to pad messages whose length is not a multiple 
of the block size. Currently, only `PKCS` is supported.
    
    With these defaults, when a user doesn't pass `mode`/`padding`, the functions apply AES encryption/decryption in `ECB` mode with `PKCS5Padding` padding.
    
    ### Why are the changes needed?
    1. Currently, `aes_encrypt()` and `aes_decrypt()` call `Cipher.getInstance("AES")` and so rely on the JVM's default choice of cipher mode and padding. This is problematic because that default is not fixed across JVM versions and systems. By using constant defaults for the new arguments, we can guarantee the same behaviour across all supported platforms.
    2. The new arguments also act as an extension point for the AES implementation in Spark SQL: in the future, developers (in OSS or in a private Spark fork) can implement other modes and paddings, such as GCM. Other systems already support different AES modes, see:
       1. Snowflake: 
https://docs.snowflake.com/en/sql-reference/functions/encrypt.html
       2. BigQuery: 
https://cloud.google.com/bigquery/docs/reference/standard-sql/aead-encryption-concepts#block_cipher_modes
       3. MySQL: 
https://dev.mysql.com/doc/refman/8.0/en/encryption-functions.html#function_aes-encrypt
       4. Hive: 
https://cwiki.apache.org/confluence/display/hive/languagemanual+udf
       5. PostgreSQL: 
https://www.postgresql.org/docs/12/pgcrypto.html#id-1.11.7.34.8
    
    ### Does this PR introduce _any_ user-facing change?
    No. This PR just extends existing APIs.
    
    ### How was this patch tested?
    By running new checks:
    ```
    $ build/sbt "test:testOnly org.apache.spark.sql.DataFrameFunctionsSuite"
    $ build/sbt "sql/test:testOnly 
org.apache.spark.sql.expressions.ExpressionInfoSuite"
    $ build/sbt "sql/testOnly *ExpressionsSchemaSuite"
    ```
    
    Closes #34837 from MaxGekk/aes-gsm-mode.
    
    Authored-by: Max Gekk <max.g...@gmail.com>
    Signed-off-by: Kousuke Saruta <saru...@oss.nttdata.com>
---
 .../catalyst/expressions/ExpressionImplUtils.java  | 24 +++++--
 .../spark/sql/catalyst/expressions/misc.scala      | 78 +++++++++++++++++-----
 .../spark/sql/errors/QueryExecutionErrors.scala    | 10 ++-
 .../sql-functions/sql-expression-schema.md         |  2 +-
 .../apache/spark/sql/DataFrameFunctionsSuite.scala | 16 +++++
 5 files changed, 104 insertions(+), 26 deletions(-)

diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/ExpressionImplUtils.java b/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/ExpressionImplUtils.java
index 9afa5a6..83205c1 100644
--- a/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/ExpressionImplUtils.java
+++ b/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/ExpressionImplUtils.java
@@ -18,6 +18,7 @@
 package org.apache.spark.sql.catalyst.expressions;
 
 import org.apache.spark.sql.errors.QueryExecutionErrors;
+import org.apache.spark.unsafe.types.UTF8String;
 
 import javax.crypto.Cipher;
 import javax.crypto.spec.SecretKeySpec;
@@ -27,19 +28,34 @@ import java.security.GeneralSecurityException;
  * An utility class for constructing expressions.
  */
 public class ExpressionImplUtils {
-  public static byte[] aesEncrypt(byte[] input, byte[] key) {
-    return aesInternal(input, key, Cipher.ENCRYPT_MODE);
+  public static byte[] aesEncrypt(byte[] input, byte[] key, UTF8String mode, UTF8String padding) {
+    return aesInternal(input, key, mode.toString(), padding.toString(), Cipher.ENCRYPT_MODE);
   }
 
-  public static byte[] aesDecrypt(byte[] input, byte[] key) {
-    return aesInternal(input, key, Cipher.DECRYPT_MODE);
+  public static byte[] aesDecrypt(byte[] input, byte[] key, UTF8String mode, UTF8String padding) {
+    return aesInternal(input, key, mode.toString(), padding.toString(), Cipher.DECRYPT_MODE);
   }
 
-  private static byte[] aesInternal(byte[] input, byte[] key, int mode) {
+  /**
+   * Runs AES over `input` with `key`. Only the ECB mode combined with PKCS padding (compared
+   * case-insensitively) is accepted; any other `mode`/`padding` pair is rejected with an error.
+   *
+   * @param cipherMode Cipher.ENCRYPT_MODE or Cipher.DECRYPT_MODE, as passed by aesEncrypt/aesDecrypt
+   */
+  private static byte[] aesInternal(
+      byte[] input,
+      byte[] key,
+      String mode,
+      String padding,
+      int cipherMode) {
     int inputLength = input.length;
     int keyLength = key.length;
     SecretKeySpec secretKey;
 
+    boolean ecbWithPkcs = mode.equalsIgnoreCase("ECB") && padding.equalsIgnoreCase("PKCS");
+    if (!ecbWithPkcs) {
+      throw QueryExecutionErrors.aesModeUnsupportedError(mode, padding);
+    }
+
     switch (keyLength) {
       case 16:
       case 24:
@@ -51,8 +67,8 @@ public class ExpressionImplUtils {
       }
 
     try {
-      Cipher cipher = Cipher.getInstance("AES");
-      cipher.init(mode, secretKey);
+      Cipher cipher = Cipher.getInstance("AES/ECB/PKCS5Padding");
+      cipher.init(cipherMode, secretKey);
       return cipher.doFinal(input, 0, inputLength);
     } catch (GeneralSecurityException e) {
         throw new RuntimeException(e);
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/misc.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/misc.scala
index 02eae3f..d891623 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/misc.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/misc.scala
@@ -310,33 +310,57 @@ case class CurrentUser() extends LeafExpression with Unevaluable {
- * If either argument is NULL or the key length is not one of the permitted values,
- * the return value is NULL.
+ * If any argument is NULL, the return value is NULL. A key whose length is not 16, 24 or
+ * 32 bytes, or an unsupported `mode`/`padding` combination, raises an error.
  */
+// scalastyle:off line.size.limit
 @ExpressionDescription(
   usage = """
-    _FUNC_(expr, key) - Returns an encrypted value of `expr` using AES.
-      Key lengths of 16, 24 and 32 bits are supported.
+    _FUNC_(expr, key[, mode[, padding]]) - Returns an encrypted value of `expr` using AES in the given `mode` with the specified `padding`.
+      Key lengths of 16, 24 and 32 bytes are supported.
   """,
+  arguments = """
+    Arguments:
+      * expr - The binary value to encrypt.
+      * key - The passphrase to use to encrypt the data.
+      * mode - Specifies which block cipher mode should be used to encrypt messages.
+               Supported modes: ECB.
+      * padding - Specifies how to pad messages whose length is not a multiple of the block size.
+                  Valid values: PKCS.
+  """,
   examples = """
     Examples:
       > SELECT base64(_FUNC_('Spark', 'abcdefghijklmnop'));
        4Hv0UKCx6nfUeAoPZo1z+w==
+      > SELECT base64(_FUNC_('Spark SQL', '1234567890abcdef', 'ECB', 'PKCS'));
+       3lmwu+Mw0H3fi5NDvcu9lg==
   """,
   since = "3.3.0",
   group = "misc_funcs")
-case class AesEncrypt(input: Expression, key: Expression, child: Expression)
-    extends RuntimeReplaceable {
+case class AesEncrypt(
+    input: Expression,
+    key: Expression,
+    mode: Expression,
+    padding: Expression,
+    child: Expression)
+  extends RuntimeReplaceable {
 
-  def this(input: Expression, key: Expression) = {
-    this(input,
+  def this(input: Expression, key: Expression, mode: Expression, padding: Expression) = {
+    this(
+      input,
       key,
+      mode,
+      padding,
       StaticInvoke(
         classOf[ExpressionImplUtils],
         BinaryType,
         "aesEncrypt",
-        Seq(input, key),
-        Seq(BinaryType, BinaryType)))
+        Seq(input, key, mode, padding),
+        Seq(BinaryType, BinaryType, StringType, StringType)))
   }
+  def this(input: Expression, key: Expression, mode: Expression) =
+    this(input, key, mode, Literal("PKCS"))
+  def this(input: Expression, key: Expression) =
+    this(input, key, Literal("ECB"))
 
-  def exprsReplaced: Seq[Expression] = Seq(input, key)
+  def exprsReplaced: Seq[Expression] = Seq(input, key, mode, padding)
   protected def withNewChildInternal(newChild: Expression): AesEncrypt =
     copy(child = newChild)
 }
@@ -350,31 +374,55 @@ case class AesEncrypt(input: Expression, key: Expression, child: Expression)
  */
 @ExpressionDescription(
   usage = """
-    _FUNC_(expr, key) - Returns a decrepted value of `expr` using AES.
-      Key lengths of 16, 24 and 32 bits are supported.
+    _FUNC_(expr, key[, mode[, padding]]) - Returns a decrypted value of `expr` using AES in `mode` with `padding`.
+      Key lengths of 16, 24 and 32 bytes are supported.
   """,
+  arguments = """
+    Arguments:
+      * expr - The binary value to decrypt.
+      * key - The passphrase to use to decrypt the data.
+      * mode - Specifies which block cipher mode should be used to decrypt messages.
+               Supported modes: ECB.
+      * padding - Specifies how to pad messages whose length is not a multiple of the block size.
+                  Valid values: PKCS.
+  """,
   examples = """
     Examples:
       > SELECT _FUNC_(unbase64('4Hv0UKCx6nfUeAoPZo1z+w=='), 'abcdefghijklmnop');
        Spark
+      > SELECT _FUNC_(unbase64('3lmwu+Mw0H3fi5NDvcu9lg=='), '1234567890abcdef', 'ECB', 'PKCS');
+       Spark SQL
   """,
   since = "3.3.0",
   group = "misc_funcs")
-case class AesDecrypt(input: Expression, key: Expression, child: Expression)
-    extends RuntimeReplaceable {
+case class AesDecrypt(
+    input: Expression,
+    key: Expression,
+    mode: Expression,
+    padding: Expression,
+    child: Expression)
+  extends RuntimeReplaceable {
 
-  def this(input: Expression, key: Expression) = {
-    this(input,
+  def this(input: Expression, key: Expression, mode: Expression, padding: Expression) = {
+    this(
+      input,
       key,
+      mode,
+      padding,
       StaticInvoke(
         classOf[ExpressionImplUtils],
         BinaryType,
         "aesDecrypt",
-        Seq(input, key),
-        Seq(BinaryType, BinaryType)))
+        Seq(input, key, mode, padding),
+        Seq(BinaryType, BinaryType, StringType, StringType)))
   }
+  def this(input: Expression, key: Expression, mode: Expression) =
+    this(input, key, mode, Literal("PKCS"))
+  def this(input: Expression, key: Expression) =
+    this(input, key, Literal("ECB"))
 
-  def exprsReplaced: Seq[Expression] = Seq(input, key)
+  def exprsReplaced: Seq[Expression] = Seq(input, key, mode, padding)
   protected def withNewChildInternal(newChild: Expression): AesDecrypt =
     copy(child = newChild)
 }
+// scalastyle:on line.size.limit
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/errors/QueryExecutionErrors.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/errors/QueryExecutionErrors.scala
index d307aa9..a316ebc 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/errors/QueryExecutionErrors.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/errors/QueryExecutionErrors.scala
@@ -1893,9 +1893,13 @@ object QueryExecutionErrors {
   }
 
   def invalidAesKeyLengthError(actualLength: Int): RuntimeException = {
-    new RuntimeException(
-      s"The key length of aes_encrypt/aes_decrypt should be " +
-        "one of 16, 24 or 32 bytes, but got: $actualLength")
+    new RuntimeException("The key length of aes_encrypt/aes_decrypt should be " +
+      s"one of 16, 24 or 32 bytes, but got: $actualLength")
+  }
+
+  def aesModeUnsupportedError(mode: String, padding: String): RuntimeException = {
+    new UnsupportedOperationException(
+      s"The AES mode $mode with the padding $padding is not supported")
   }
 
   def hiveTableWithAnsiIntervalsError(tableName: String): Throwable = {
diff --git a/sql/core/src/test/resources/sql-functions/sql-expression-schema.md b/sql/core/src/test/resources/sql-functions/sql-expression-schema.md
index 25dbffa..06c165f 100644
--- a/sql/core/src/test/resources/sql-functions/sql-expression-schema.md
+++ b/sql/core/src/test/resources/sql-functions/sql-expression-schema.md
@@ -12,7 +12,7 @@
 | org.apache.spark.sql.catalyst.expressions.Add | + | SELECT 1 + 2 | struct<(1 + 2):int> |
 | org.apache.spark.sql.catalyst.expressions.AddMonths | add_months | SELECT add_months('2016-08-31', 1) | struct<add_months(2016-08-31, 1):date> |
-| org.apache.spark.sql.catalyst.expressions.AesDecrypt | aes_decrypt | SELECT aes_decrypt(unbase64('4Hv0UKCx6nfUeAoPZo1z+w=='), 'abcdefghijklmnop') | struct<aesdecrypt(unbase64(4Hv0UKCx6nfUeAoPZo1z+w==), abcdefghijklmnop):binary> |
-| org.apache.spark.sql.catalyst.expressions.AesEncrypt | aes_encrypt | SELECT base64(aes_encrypt('Spark', 'abcdefghijklmnop')) | struct<base64(aesencrypt(Spark, abcdefghijklmnop)):string> |
+| org.apache.spark.sql.catalyst.expressions.AesDecrypt | aes_decrypt | SELECT aes_decrypt(unbase64('4Hv0UKCx6nfUeAoPZo1z+w=='), 'abcdefghijklmnop') | struct<aesdecrypt(unbase64(4Hv0UKCx6nfUeAoPZo1z+w==), abcdefghijklmnop, ECB, PKCS):binary> |
+| org.apache.spark.sql.catalyst.expressions.AesEncrypt | aes_encrypt | SELECT base64(aes_encrypt('Spark', 'abcdefghijklmnop')) | struct<base64(aesencrypt(Spark, abcdefghijklmnop, ECB, PKCS)):string> |
 | org.apache.spark.sql.catalyst.expressions.And | and | SELECT true and true | struct<(true AND true):boolean> |
 | org.apache.spark.sql.catalyst.expressions.ArrayAggregate | aggregate | SELECT aggregate(array(1, 2, 3), 0, (acc, x) -> acc + x) | struct<aggregate(array(1, 2, 3), 0, lambdafunction((namedlambdavariable() + namedlambdavariable()), namedlambdavariable(), namedlambdavariable()), lambdafunction(namedlambdavariable(), namedlambdavariable())):int> |
 | org.apache.spark.sql.catalyst.expressions.ArrayContains | array_contains | SELECT array_contains(array(1, 2, 3), 2) | struct<array_contains(array(1, 2, 3), 2):boolean> |
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala
index b25bd28..91f5625 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala
@@ -269,6 +269,14 @@ class DataFrameFunctionsSuite extends QueryTest with SharedSparkSession {
         "The key length of aes_encrypt/aes_decrypt should be one of 16, 24 or 32 bytes"))
     }
 
+    def checkUnsupportedMode(df: => DataFrame): Unit = {
+      val e = intercept[SparkException] {
+        df.collect()
+      }.getCause
+      assert(e.isInstanceOf[UnsupportedOperationException])
+      assert(e.getMessage.matches("""The AES mode \w+ with the padding \w+ is not supported"""))
+    }
+
     val df1 = Seq("Spark", "").toDF
 
     // Successful encryption
@@ -312,6 +320,10 @@ class DataFrameFunctionsSuite extends QueryTest with SharedSparkSession {
     checkInvalidKeyLength(df1.selectExpr("aes_encrypt(value, binary('123456789012345'))"))
     checkInvalidKeyLength(df1.selectExpr("aes_encrypt(value, binary(''))"))
 
+    // Unsupported AES mode and padding in encrypt
+    checkUnsupportedMode(df1.selectExpr(s"aes_encrypt(value, '$key16', 'CBC')"))
+    checkUnsupportedMode(df1.selectExpr(s"aes_encrypt(value, '$key16', 'ECB', 'NoPadding')"))
+
     val df2 = Seq(
       (encryptedText16, encryptedText24, encryptedText32),
       (encryptedEmptyText16, encryptedEmptyText24, encryptedEmptyText32)
@@ -372,6 +384,16 @@ class DataFrameFunctionsSuite extends QueryTest with SharedSparkSession {
         }
         assert(e.getMessage.contains("BadPaddingException"))
     }
+
+    // Unsupported AES mode and padding in decrypt
+    checkUnsupportedMode(df2.selectExpr(s"aes_decrypt(value16, '$key16', 'GCM')"))
+    checkUnsupportedMode(df2.selectExpr(s"aes_decrypt(value32, '$key32', 'ECB', 'None')"))
+
+    // Explicit mode/padding (matched case-insensitively) must round-trip like the defaults
+    checkAnswer(
+      df1.selectExpr(
+        s"cast(aes_decrypt(aes_encrypt(value, '$key16', 'ecb', 'pkcs'), '$key16', 'ECB', 'PKCS') AS STRING)"),
+      Seq(Row("Spark"), Row("")))
   }
 
   test("string function find_in_set") {

---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscr...@spark.apache.org
For additional commands, e-mail: commits-h...@spark.apache.org

Reply via email to