This is an automated email from the ASF dual-hosted git repository.
mbutrovich pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/datafusion-comet.git
The following commit(s) were added to refs/heads/main by this push:
new 58060e2be fix: Normalize s3 paths for PME key retriever (#2874)
58060e2be is described below
commit 58060e2be01b4e4c171d62bbe4cf881d00e37fdb
Author: Matt Butrovich <[email protected]>
AuthorDate: Wed Dec 10 20:54:29 2025 -0500
fix: Normalize s3 paths for PME key retriever (#2874)
---
.../comet/parquet/CometFileKeyUnwrapper.java | 27 +++++++++++++-
.../comet/parquet/ParquetReadFromS3Suite.scala | 43 ++++++++++++++++++++++
2 files changed, 68 insertions(+), 2 deletions(-)
diff --git a/common/src/main/java/org/apache/comet/parquet/CometFileKeyUnwrapper.java b/common/src/main/java/org/apache/comet/parquet/CometFileKeyUnwrapper.java
index 0911901d2..1e71c25a0 100644
--- a/common/src/main/java/org/apache/comet/parquet/CometFileKeyUnwrapper.java
+++ b/common/src/main/java/org/apache/comet/parquet/CometFileKeyUnwrapper.java
@@ -101,6 +101,27 @@ public class CometFileKeyUnwrapper {
// Cache the hadoopConf just to assert the assumption above.
private Configuration conf = null;
+ /**
+ * Normalizes S3 URI schemes to a canonical form. S3 can be accessed via multiple schemes (s3://,
+ * s3a://, s3n://) that refer to the same logical filesystem. This method ensures consistent cache
+ * lookups regardless of which scheme is used.
+ *
+ * @param filePath The file path that may contain an S3 URI
+ * @return The file path with normalized S3 scheme (s3a://)
+ */
+ private String normalizeS3Scheme(final String filePath) {
+ // Normalize s3:// and s3n:// to s3a:// for consistent cache lookups
+ // This handles the case where ObjectStoreUrl uses s3:// but Spark uses s3a://
+ String s3Prefix = "s3://";
+ String s3nPrefix = "s3n://";
+ if (filePath.startsWith(s3Prefix)) {
+ return "s3a://" + filePath.substring(s3Prefix.length());
+ } else if (filePath.startsWith(s3nPrefix)) {
+ return "s3a://" + filePath.substring(s3nPrefix.length());
+ }
+ return filePath;
+ }
+
/**
* Creates and stores a DecryptionKeyRetriever instance for the given file path.
*
@@ -108,6 +129,7 @@ public class CometFileKeyUnwrapper {
* @param hadoopConf The Hadoop Configuration to use for this file path
*/
public void storeDecryptionKeyRetriever(final String filePath, final Configuration hadoopConf) {
+ final String normalizedPath = normalizeS3Scheme(filePath);
// Use DecryptionPropertiesFactory.loadFactory to get the factory and then call
// getFileDecryptionProperties
if (factory == null) {
@@ -122,7 +144,7 @@ public class CometFileKeyUnwrapper {
factory.getFileDecryptionProperties(hadoopConf, path);
DecryptionKeyRetriever keyRetriever = decryptionProperties.getKeyRetriever();
- retrieverCache.put(filePath, keyRetriever);
+ retrieverCache.put(normalizedPath, keyRetriever);
}
/**
@@ -136,7 +158,8 @@ public class CometFileKeyUnwrapper {
*/
public byte[] getKey(final String filePath, final byte[] keyMetadata)
throws ParquetCryptoRuntimeException {
- DecryptionKeyRetriever keyRetriever = retrieverCache.get(filePath);
+ final String normalizedPath = normalizeS3Scheme(filePath);
+ DecryptionKeyRetriever keyRetriever = retrieverCache.get(normalizedPath);
if (keyRetriever == null) {
throw new ParquetCryptoRuntimeException(
"Failed to find DecryptionKeyRetriever for path: " + filePath);
diff --git a/spark/src/test/scala/org/apache/comet/parquet/ParquetReadFromS3Suite.scala b/spark/src/test/scala/org/apache/comet/parquet/ParquetReadFromS3Suite.scala
index 0fd512c61..a42a025c2 100644
--- a/spark/src/test/scala/org/apache/comet/parquet/ParquetReadFromS3Suite.scala
+++ b/spark/src/test/scala/org/apache/comet/parquet/ParquetReadFromS3Suite.scala
@@ -19,6 +19,12 @@
package org.apache.comet.parquet
+import java.nio.charset.StandardCharsets
+import java.util.Base64
+
+import org.apache.parquet.crypto.DecryptionPropertiesFactory
+import org.apache.parquet.crypto.keytools.{KeyToolkit, PropertiesDrivenCryptoFactory}
+import org.apache.parquet.crypto.keytools.mocks.InMemoryKMS
import org.apache.spark.sql.{DataFrame, SaveMode}
import org.apache.spark.sql.comet.{CometNativeScanExec, CometScanExec}
import org.apache.spark.sql.execution.adaptive.AdaptiveSparkPlanHelper
@@ -30,6 +36,15 @@ class ParquetReadFromS3Suite extends CometS3TestBase with AdaptiveSparkPlanHelpe
override protected val testBucketName = "test-bucket"
+ // Encryption keys for testing parquet encryption
+ private val encoder = Base64.getEncoder
+ private val footerKey =
+ encoder.encodeToString("0123456789012345".getBytes(StandardCharsets.UTF_8))
+ private val key1 = encoder.encodeToString("1234567890123450".getBytes(StandardCharsets.UTF_8))
+ private val key2 = encoder.encodeToString("1234567890123451".getBytes(StandardCharsets.UTF_8))
+ private val cryptoFactoryClass =
+ "org.apache.parquet.crypto.keytools.PropertiesDrivenCryptoFactory"
+
private def writeTestParquetFile(filePath: String): Unit = {
val df = spark.range(0, 1000)
df.write.format("parquet").mode(SaveMode.Overwrite).save(filePath)
@@ -76,4 +91,32 @@ class ParquetReadFromS3Suite extends CometS3TestBase with AdaptiveSparkPlanHelpe
assertCometScan(df)
assert(df.first().getLong(0) == 499500)
}
+
+ test("write and read encrypted parquet from S3") {
+ import testImplicits._
+
+ withSQLConf(
+ DecryptionPropertiesFactory.CRYPTO_FACTORY_CLASS_PROPERTY_NAME -> cryptoFactoryClass,
+ KeyToolkit.KMS_CLIENT_CLASS_PROPERTY_NAME ->
+ "org.apache.parquet.crypto.keytools.mocks.InMemoryKMS",
+ InMemoryKMS.KEY_LIST_PROPERTY_NAME ->
+ s"footerKey: ${footerKey}, key1: ${key1}, key2: ${key2}") {
+
+ val inputDF = spark
+ .range(0, 1000)
+ .map(i => (i, i.toString, i.toFloat))
+ .repartition(5)
+ .toDF("a", "b", "c")
+
+ val testFilePath = s"s3a://$testBucketName/data/encrypted-test.parquet"
+ inputDF.write
+ .option(PropertiesDrivenCryptoFactory.COLUMN_KEYS_PROPERTY_NAME, "key1: a, b; key2: c")
+ .option(PropertiesDrivenCryptoFactory.FOOTER_KEY_PROPERTY_NAME, "footerKey")
+ .parquet(testFilePath)
+
+ val df = spark.read.parquet(testFilePath).agg(sum(col("a")))
+ assertCometScan(df)
+ assert(df.first().getLong(0) == 499500)
+ }
+ }
}
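
For reference (not part of the commit): a minimal standalone sketch of the cache-key normalization added above. The class name and main method below are hypothetical illustration only; the normalization logic mirrors the private normalizeS3Scheme helper in the diff.

import java.util.Arrays;
import java.util.List;

// Hypothetical demo class; copies the logic of the private normalizeS3Scheme helper above.
public final class S3SchemeNormalizationDemo {

  // s3:// and s3n:// are rewritten to s3a:// so that store and lookup
  // use the same cache key regardless of the incoming scheme.
  static String normalizeS3Scheme(final String filePath) {
    if (filePath.startsWith("s3://")) {
      return "s3a://" + filePath.substring("s3://".length());
    } else if (filePath.startsWith("s3n://")) {
      return "s3a://" + filePath.substring("s3n://".length());
    }
    return filePath;
  }

  public static void main(String[] args) {
    List<String> paths =
        Arrays.asList(
            "s3://bucket/data/part-0.parquet",   // ObjectStoreUrl-style scheme
            "s3a://bucket/data/part-0.parquet",  // Spark/Hadoop s3a scheme
            "s3n://bucket/data/part-0.parquet",  // legacy s3n scheme
            "hdfs://nn/data/part-0.parquet");    // non-S3 paths are left unchanged
    // All three S3 variants print the same s3a:// form, so retrieverCache.put
    // and retrieverCache.get agree on the key.
    for (String p : paths) {
      System.out.println(p + " -> " + normalizeS3Scheme(p));
    }
  }
}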
---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]