This is an automated email from the ASF dual-hosted git repository.

mbutrovich pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/datafusion-comet.git


The following commit(s) were added to refs/heads/main by this push:
     new 58060e2be fix: Normalize s3 paths for PME key retriever (#2874)
58060e2be is described below

commit 58060e2be01b4e4c171d62bbe4cf881d00e37fdb
Author: Matt Butrovich <[email protected]>
AuthorDate: Wed Dec 10 20:54:29 2025 -0500

    fix: Normalize s3 paths for PME key retriever (#2874)
---
 .../comet/parquet/CometFileKeyUnwrapper.java       | 27 +++++++++++++-
 .../comet/parquet/ParquetReadFromS3Suite.scala     | 43 ++++++++++++++++++++++
 2 files changed, 68 insertions(+), 2 deletions(-)

diff --git a/common/src/main/java/org/apache/comet/parquet/CometFileKeyUnwrapper.java b/common/src/main/java/org/apache/comet/parquet/CometFileKeyUnwrapper.java
index 0911901d2..1e71c25a0 100644
--- a/common/src/main/java/org/apache/comet/parquet/CometFileKeyUnwrapper.java
+++ b/common/src/main/java/org/apache/comet/parquet/CometFileKeyUnwrapper.java
@@ -101,6 +101,27 @@ public class CometFileKeyUnwrapper {
   // Cache the hadoopConf just to assert the assumption above.
   private Configuration conf = null;
 
+  /**
+   * Normalizes S3 URI schemes to a canonical form. S3 can be accessed via multiple schemes (s3://,
+   * s3a://, s3n://) that refer to the same logical filesystem. This method ensures consistent cache
+   * lookups regardless of which scheme is used.
+   *
+   * @param filePath The file path that may contain an S3 URI
+   * @return The file path with normalized S3 scheme (s3a://)
+   */
+  private String normalizeS3Scheme(final String filePath) {
+    // Normalize s3:// and s3n:// to s3a:// for consistent cache lookups
+    // This handles the case where ObjectStoreUrl uses s3:// but Spark uses s3a://
+    String s3Prefix = "s3://";
+    String s3nPrefix = "s3n://";
+    if (filePath.startsWith(s3Prefix)) {
+      return "s3a://" + filePath.substring(s3Prefix.length());
+    } else if (filePath.startsWith(s3nPrefix)) {
+      return "s3a://" + filePath.substring(s3nPrefix.length());
+    }
+    return filePath;
+  }
+
   /**
    * Creates and stores a DecryptionKeyRetriever instance for the given file path.
    *
@@ -108,6 +129,7 @@ public class CometFileKeyUnwrapper {
    * @param hadoopConf The Hadoop Configuration to use for this file path
    */
   public void storeDecryptionKeyRetriever(final String filePath, final Configuration hadoopConf) {
+    final String normalizedPath = normalizeS3Scheme(filePath);
     // Use DecryptionPropertiesFactory.loadFactory to get the factory and then call
     // getFileDecryptionProperties
     if (factory == null) {
@@ -122,7 +144,7 @@ public class CometFileKeyUnwrapper {
         factory.getFileDecryptionProperties(hadoopConf, path);
 
     DecryptionKeyRetriever keyRetriever = decryptionProperties.getKeyRetriever();
-    retrieverCache.put(filePath, keyRetriever);
+    retrieverCache.put(normalizedPath, keyRetriever);
   }
 
   /**
@@ -136,7 +158,8 @@ public class CometFileKeyUnwrapper {
    */
   public byte[] getKey(final String filePath, final byte[] keyMetadata)
       throws ParquetCryptoRuntimeException {
-    DecryptionKeyRetriever keyRetriever = retrieverCache.get(filePath);
+    final String normalizedPath = normalizeS3Scheme(filePath);
+    DecryptionKeyRetriever keyRetriever = retrieverCache.get(normalizedPath);
     if (keyRetriever == null) {
       throw new ParquetCryptoRuntimeException(
           "Failed to find DecryptionKeyRetriever for path: " + filePath);
diff --git 
a/spark/src/test/scala/org/apache/comet/parquet/ParquetReadFromS3Suite.scala 
b/spark/src/test/scala/org/apache/comet/parquet/ParquetReadFromS3Suite.scala
index 0fd512c61..a42a025c2 100644
--- a/spark/src/test/scala/org/apache/comet/parquet/ParquetReadFromS3Suite.scala
+++ b/spark/src/test/scala/org/apache/comet/parquet/ParquetReadFromS3Suite.scala
@@ -19,6 +19,12 @@
 
 package org.apache.comet.parquet
 
+import java.nio.charset.StandardCharsets
+import java.util.Base64
+
+import org.apache.parquet.crypto.DecryptionPropertiesFactory
+import org.apache.parquet.crypto.keytools.{KeyToolkit, PropertiesDrivenCryptoFactory}
+import org.apache.parquet.crypto.keytools.mocks.InMemoryKMS
 import org.apache.spark.sql.{DataFrame, SaveMode}
 import org.apache.spark.sql.comet.{CometNativeScanExec, CometScanExec}
 import org.apache.spark.sql.execution.adaptive.AdaptiveSparkPlanHelper
@@ -30,6 +36,15 @@ class ParquetReadFromS3Suite extends CometS3TestBase with AdaptiveSparkPlanHelpe
 
   override protected val testBucketName = "test-bucket"
 
+  // Encryption keys for testing parquet encryption
+  private val encoder = Base64.getEncoder
+  private val footerKey =
+    encoder.encodeToString("0123456789012345".getBytes(StandardCharsets.UTF_8))
+  private val key1 = encoder.encodeToString("1234567890123450".getBytes(StandardCharsets.UTF_8))
+  private val key2 = encoder.encodeToString("1234567890123451".getBytes(StandardCharsets.UTF_8))
+  private val cryptoFactoryClass =
+    "org.apache.parquet.crypto.keytools.PropertiesDrivenCryptoFactory"
+
   private def writeTestParquetFile(filePath: String): Unit = {
     val df = spark.range(0, 1000)
     df.write.format("parquet").mode(SaveMode.Overwrite).save(filePath)
@@ -76,4 +91,32 @@ class ParquetReadFromS3Suite extends CometS3TestBase with AdaptiveSparkPlanHelpe
     assertCometScan(df)
     assert(df.first().getLong(0) == 499500)
   }
+
+  test("write and read encrypted parquet from S3") {
+    import testImplicits._
+
+    withSQLConf(
+      DecryptionPropertiesFactory.CRYPTO_FACTORY_CLASS_PROPERTY_NAME -> cryptoFactoryClass,
+      KeyToolkit.KMS_CLIENT_CLASS_PROPERTY_NAME ->
+        "org.apache.parquet.crypto.keytools.mocks.InMemoryKMS",
+      InMemoryKMS.KEY_LIST_PROPERTY_NAME ->
+        s"footerKey: ${footerKey}, key1: ${key1}, key2: ${key2}") {
+
+      val inputDF = spark
+        .range(0, 1000)
+        .map(i => (i, i.toString, i.toFloat))
+        .repartition(5)
+        .toDF("a", "b", "c")
+
+      val testFilePath = s"s3a://$testBucketName/data/encrypted-test.parquet"
+      inputDF.write
+        .option(PropertiesDrivenCryptoFactory.COLUMN_KEYS_PROPERTY_NAME, "key1: a, b; key2: c")
+        .option(PropertiesDrivenCryptoFactory.FOOTER_KEY_PROPERTY_NAME, "footerKey")
+        .parquet(testFilePath)
+
+      val df = spark.read.parquet(testFilePath).agg(sum(col("a")))
+      assertCometScan(df)
+      assert(df.first().getLong(0) == 499500)
+    }
+  }
 }
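
For illustration only, a minimal standalone sketch of the cache miss this patch avoids. The class name, the HashMap, and the String stand-in for the retriever are hypothetical; only the normalization logic mirrors CometFileKeyUnwrapper#normalizeS3Scheme above. With the normalization applied on both store and lookup, an s3:// registration is found again under the equivalent s3a:// path; without it, the lookup returns null and getKey() would throw the "Failed to find DecryptionKeyRetriever" ParquetCryptoRuntimeException shown in the diff.

    import java.util.HashMap;
    import java.util.Map;

    // Hypothetical sketch class, not part of the Comet patch.
    public class S3SchemeNormalizationSketch {

      // Stand-in for retrieverCache; the real cache maps path -> DecryptionKeyRetriever.
      private static final Map<String, String> retrieverCache = new HashMap<>();

      // Mirrors the patch: map s3:// and s3n:// onto s3a:// so store and lookup agree.
      private static String normalizeS3Scheme(String filePath) {
        if (filePath.startsWith("s3://")) {
          return "s3a://" + filePath.substring("s3://".length());
        } else if (filePath.startsWith("s3n://")) {
          return "s3a://" + filePath.substring("s3n://".length());
        }
        return filePath;
      }

      public static void main(String[] args) {
        // Stored under the s3:// form of the path ...
        retrieverCache.put(normalizeS3Scheme("s3://test-bucket/data/encrypted-test.parquet"), "retriever");
        // ... and still found under the s3a:// form that Spark's Hadoop FileSystem uses.
        System.out.println(
            retrieverCache.get(normalizeS3Scheme("s3a://test-bucket/data/encrypted-test.parquet")));
      }
    }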


---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]
