This is an automated email from the ASF dual-hosted git repository.
mbutrovich pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/datafusion-comet.git
The following commit(s) were added to refs/heads/main by this push:
new c23dc2524 feat: Parquet Modular Encryption with Spark KMS for native readers (#2447)
c23dc2524 is described below
commit c23dc25242f2753baa98f47407dd74691dd8d2a5
Author: Matt Butrovich <[email protected]>
AuthorDate: Mon Oct 6 22:05:29 2025 -0400
feat: Parquet Modular Encryption with Spark KMS for native readers (#2447)
---
.../comet/parquet/CometFileKeyUnwrapper.java | 146 ++++++
.../main/java/org/apache/comet/parquet/Native.java | 5 +-
.../apache/comet/parquet/NativeBatchReader.java | 14 +-
.../apache/comet/objectstore/NativeConfig.scala | 2 +-
.../apache/comet/parquet/CometParquetUtils.scala | 44 ++
native/Cargo.lock | 4 +
native/core/Cargo.toml | 2 +-
native/core/src/errors.rs | 9 +
native/core/src/execution/jni_api.rs | 13 +
native/core/src/execution/planner.rs | 2 +
native/core/src/parquet/encryption_support.rs | 172 +++++++
native/core/src/parquet/mod.rs | 23 +
native/core/src/parquet/parquet_exec.rs | 33 +-
native/proto/src/proto/operator.proto | 1 +
.../scala/org/apache/comet/CometExecIterator.scala | 30 +-
spark/src/main/scala/org/apache/comet/Native.scala | 5 +-
.../org/apache/comet/rules/CometScanRule.scala | 36 +-
.../org/apache/comet/serde/QueryPlanSerde.scala | 5 +
.../shuffle/CometNativeShuffleWriter.scala | 4 +-
.../org/apache/spark/sql/comet/operators.scala | 65 ++-
.../spark/sql/benchmark/CometBenchmarkBase.scala | 40 ++
.../spark/sql/benchmark/CometReadBenchmark.scala | 103 +++-
.../spark/sql/comet/ParquetEncryptionITCase.scala | 528 +++++++++++++++++----
23 files changed, 1165 insertions(+), 121 deletions(-)
diff --git a/common/src/main/java/org/apache/comet/parquet/CometFileKeyUnwrapper.java b/common/src/main/java/org/apache/comet/parquet/CometFileKeyUnwrapper.java
new file mode 100644
index 000000000..0911901d2
--- /dev/null
+++ b/common/src/main/java/org/apache/comet/parquet/CometFileKeyUnwrapper.java
@@ -0,0 +1,146 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package org.apache.comet.parquet;
+
+import java.util.concurrent.ConcurrentHashMap;
+
+import org.apache.hadoop.conf.Configuration;
+import org.apache.hadoop.fs.Path;
+import org.apache.parquet.crypto.DecryptionKeyRetriever;
+import org.apache.parquet.crypto.DecryptionPropertiesFactory;
+import org.apache.parquet.crypto.FileDecryptionProperties;
+import org.apache.parquet.crypto.ParquetCryptoRuntimeException;
+
+// spotless:off
+/*
+ * Architecture Overview:
+ *
+ *              JVM Side                  |              Native Side
+ * ┌─────────────────────────────────────┐ | ┌─────────────────────────────────────┐
+ * │ CometFileKeyUnwrapper               │ | │ Parquet File Reading                │
+ * │                                     │ | │                                     │
+ * │  ┌─────────────────────────────┐    │ | │  ┌─────────────────────────────┐    │
+ * │  │ hadoopConf                  │    │ | │  │ file1.parquet               │    │
+ * │  │ (Configuration)             │    │ | │  │ file2.parquet               │    │
+ * │  └─────────────────────────────┘    │ | │  │ file3.parquet               │    │
+ * │               │                     │ | │  └─────────────────────────────┘    │
+ * │               ▼                     │ | │               │                     │
+ * │  ┌─────────────────────────────┐    │ | │               │                     │
+ * │  │ factoryCache                │    │ | │               ▼                     │
+ * │  │ (many-to-one mapping)       │    │ | │  ┌─────────────────────────────┐    │
+ * │  │                             │    │ | │  │ Parse file metadata &       │    │
+ * │  │  file1 ──┐                  │    │ | │  │ extract keyMetadata         │    │
+ * │  │  file2 ──┼─► DecryptionProps│    │ | │  └─────────────────────────────┘    │
+ * │  │  file3 ──┘    Factory       │    │ | │               │                     │
+ * │  └─────────────────────────────┘    │ | │               │                     │
+ * │               │                     │ | │               ▼                     │
+ * │               ▼                     │ | │  ╔═════════════════════════════╗    │
+ * │  ┌─────────────────────────────┐    │ | │  ║ JNI CALL:                   ║    │
+ * │  │ retrieverCache              │    │ | │  ║ getKey(filePath,            ║    │
+ * │  │ filePath -> KeyRetriever    │◄───┼───┼───┼──║ keyMetadata)           ║    │
+ * │  └─────────────────────────────┘    │ | │  ╚═════════════════════════════╝    │
+ * │               │                     │ | │                                     │
+ * │               ▼                     │ | │                                     │
+ * │  ┌─────────────────────────────┐    │ | │                                     │
+ * │  │ DecryptionKeyRetriever      │    │ | │                                     │
+ * │  │ .getKey(keyMetadata)        │    │ | │                                     │
+ * │  └─────────────────────────────┘    │ | │                                     │
+ * │               │                     │ | │                                     │
+ * │               ▼                     │ | │                                     │
+ * │  ┌─────────────────────────────┐    │ | │  ┌─────────────────────────────┐    │
+ * │  │ return key bytes            │────┼───┼───┼─►│ Use key for decryption  │    │
+ * │  └─────────────────────────────┘    │ | │  │ of parquet data             │    │
+ * └─────────────────────────────────────┘ | │  └─────────────────────────────┘    │
+ *                                         | └─────────────────────────────────────┘
+ *                                         |
+ *                                   JNI Boundary
+ *
+ * Setup Phase (storeDecryptionKeyRetriever):
+ * 1. hadoopConf → DecryptionPropertiesFactory (cached in factoryCache)
+ * 2. Factory + filePath → DecryptionKeyRetriever (cached in retrieverCache)
+ *
+ * Runtime Phase (getKey):
+ * 3. Native code calls getKey(filePath, keyMetadata) ──► JVM
+ * 4. Retrieve cached DecryptionKeyRetriever for filePath
+ * 5. KeyRetriever.getKey(keyMetadata) → decrypted key bytes
+ * 6. Return key bytes ──► Native code for parquet decryption
+ */
+// spotless:on
+
+/**
+ * Helper class to access DecryptionKeyRetriever.getKey from native code via JNI. This class
+ * handles the complexity of creating and caching properly configured DecryptionKeyRetriever
+ * instances using DecryptionPropertiesFactory. The lifetime of this object is meant to map to a
+ * single Comet plan, so it is associated with a CometExecIterator.
+ */
+public class CometFileKeyUnwrapper {
+
+ // Each file path gets a unique DecryptionKeyRetriever
+  private final ConcurrentHashMap<String, DecryptionKeyRetriever> retrieverCache =
+      new ConcurrentHashMap<>();
+
+  // Cache the factory since we should be using the same hadoopConf for every file in this scan.
+ private DecryptionPropertiesFactory factory = null;
+ // Cache the hadoopConf just to assert the assumption above.
+ private Configuration conf = null;
+
+  /**
+   * Creates and stores a DecryptionKeyRetriever instance for the given file path.
+   *
+   * @param filePath The path to the Parquet file
+   * @param hadoopConf The Hadoop Configuration to use for this file path
+   */
+  public void storeDecryptionKeyRetriever(final String filePath, final Configuration hadoopConf) {
+    // Use DecryptionPropertiesFactory.loadFactory to get the factory and then call
+    // getFileDecryptionProperties
+    if (factory == null) {
+      factory = DecryptionPropertiesFactory.loadFactory(hadoopConf);
+      conf = hadoopConf;
+    } else {
+      // Check the assumption that all files have the same hadoopConf and thus the same factory
+      assert (conf == hadoopConf);
+    }
+    Path path = new Path(filePath);
+    FileDecryptionProperties decryptionProperties =
+        factory.getFileDecryptionProperties(hadoopConf, path);
+
+    DecryptionKeyRetriever keyRetriever = decryptionProperties.getKeyRetriever();
+    retrieverCache.put(filePath, keyRetriever);
+  }
+
+  /**
+   * Gets the decryption key for the given key metadata using the cached DecryptionKeyRetriever
+   * for the specified file path.
+   *
+   * @param filePath The path to the Parquet file
+   * @param keyMetadata The key metadata bytes from the Parquet file
+   * @return The decrypted key bytes
+   * @throws ParquetCryptoRuntimeException if key unwrapping fails
+   */
+  public byte[] getKey(final String filePath, final byte[] keyMetadata)
+      throws ParquetCryptoRuntimeException {
+    DecryptionKeyRetriever keyRetriever = retrieverCache.get(filePath);
+    if (keyRetriever == null) {
+      throw new ParquetCryptoRuntimeException(
+          "Failed to find DecryptionKeyRetriever for path: " + filePath);
+    }
+    return keyRetriever.getKey(keyMetadata);
+  }
+}
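For orientation, the two phases in the architecture comment map directly onto the two methods above. A minimal JVM-side sketch in Scala (the file path and key metadata are hypothetical placeholders; in the real flow, getKey is invoked from native code over JNI, not from user code):

    import org.apache.hadoop.conf.Configuration
    import org.apache.comet.parquet.CometFileKeyUnwrapper

    val hadoopConf = new Configuration() // assumes KMS/crypto-factory properties are set on it
    val unwrapper = new CometFileKeyUnwrapper()
    // Setup phase: one DecryptionKeyRetriever per file, all sharing one conf and factory
    unwrapper.storeDecryptionKeyRetriever("file:///tmp/data/part-0.parquet", hadoopConf)
    // Runtime phase: native code supplies the keyMetadata it parsed from the file footer
    val keyMetadata: Array[Byte] = Array.emptyByteArray // placeholder for footer key metadata
    val dataKey = unwrapper.getKey("file:///tmp/data/part-0.parquet", keyMetadata)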
diff --git a/common/src/main/java/org/apache/comet/parquet/Native.java
b/common/src/main/java/org/apache/comet/parquet/Native.java
index cceb1085c..dbddc3b74 100644
--- a/common/src/main/java/org/apache/comet/parquet/Native.java
+++ b/common/src/main/java/org/apache/comet/parquet/Native.java
@@ -267,9 +267,11 @@ public final class Native extends NativeBase {
String sessionTimezone,
int batchSize,
boolean caseSensitive,
- Map<String, String> objectStoreOptions);
+ Map<String, String> objectStoreOptions,
+ CometFileKeyUnwrapper keyUnwrapper);
// arrow native version of read batch
+
/**
* Read the next batch of data into memory on native side
*
@@ -280,6 +282,7 @@ public final class Native extends NativeBase {
  // arrow native equivalent of currentBatch. 'columnNum' is number of the column in the record
  // batch
+
/**
   * Load the column corresponding to columnNum in the currently loaded record batch into JVM
*
diff --git a/common/src/main/java/org/apache/comet/parquet/NativeBatchReader.java b/common/src/main/java/org/apache/comet/parquet/NativeBatchReader.java
index 67c277540..84918d933 100644
--- a/common/src/main/java/org/apache/comet/parquet/NativeBatchReader.java
+++ b/common/src/main/java/org/apache/comet/parquet/NativeBatchReader.java
@@ -80,7 +80,7 @@ import org.apache.comet.shims.ShimFileFormat;
import org.apache.comet.vector.CometVector;
import org.apache.comet.vector.NativeUtil;
-import static scala.jdk.javaapi.CollectionConverters.*;
+import static scala.jdk.javaapi.CollectionConverters.asJava;
/**
* A vectorized Parquet reader that reads a Parquet file in a batched fashion.
@@ -410,6 +410,15 @@ public class NativeBatchReader extends RecordReader<Void, ColumnarBatch> impleme
}
}
+ boolean encryptionEnabled = CometParquetUtils.encryptionEnabled(conf);
+
+ // Create keyUnwrapper if encryption is enabled
+ CometFileKeyUnwrapper keyUnwrapper = null;
+ if (encryptionEnabled) {
+ keyUnwrapper = new CometFileKeyUnwrapper();
+      keyUnwrapper.storeDecryptionKeyRetriever(file.filePath().toString(), conf);
+ }
+
int batchSize =
conf.getInt(
CometConf.COMET_BATCH_SIZE().key(),
@@ -426,7 +435,8 @@ public class NativeBatchReader extends RecordReader<Void, ColumnarBatch> impleme
timeZoneId,
batchSize,
caseSensitive,
- objectStoreOptions);
+ objectStoreOptions,
+ keyUnwrapper);
}
isInitialized = true;
}
diff --git a/common/src/main/scala/org/apache/comet/objectstore/NativeConfig.scala b/common/src/main/scala/org/apache/comet/objectstore/NativeConfig.scala
index b930aea17..885b4686e 100644
--- a/common/src/main/scala/org/apache/comet/objectstore/NativeConfig.scala
+++ b/common/src/main/scala/org/apache/comet/objectstore/NativeConfig.scala
@@ -58,7 +58,7 @@ object NativeConfig {
  def extractObjectStoreOptions(hadoopConf: Configuration, uri: URI): Map[String, String] = {
val scheme = uri.getScheme.toLowerCase(Locale.ROOT)
- import scala.collection.JavaConverters._
+ import scala.jdk.CollectionConverters._
val options = scala.collection.mutable.Map[String, String]()
// The schemes will use libhdfs
diff --git a/common/src/main/scala/org/apache/comet/parquet/CometParquetUtils.scala b/common/src/main/scala/org/apache/comet/parquet/CometParquetUtils.scala
index a37ec7e66..8bcf99dbd 100644
--- a/common/src/main/scala/org/apache/comet/parquet/CometParquetUtils.scala
+++ b/common/src/main/scala/org/apache/comet/parquet/CometParquetUtils.scala
@@ -20,6 +20,8 @@
package org.apache.comet.parquet
import org.apache.hadoop.conf.Configuration
+import org.apache.parquet.crypto.DecryptionPropertiesFactory
+import org.apache.parquet.crypto.keytools.{KeyToolkit, PropertiesDrivenCryptoFactory}
import org.apache.spark.sql.internal.SQLConf
object CometParquetUtils {
@@ -27,6 +29,16 @@ object CometParquetUtils {
  private val PARQUET_FIELD_ID_READ_ENABLED = "spark.sql.parquet.fieldId.read.enabled"
  private val IGNORE_MISSING_PARQUET_FIELD_ID = "spark.sql.parquet.fieldId.read.ignoreMissing"
+  // Map of encryption configuration key-value pairs that, if present, are only supported with
+  // these specific values. Generally, these are the default values that won't be present,
+  // but if they are present we want to check them.
+  private val SUPPORTED_ENCRYPTION_CONFIGS: Map[String, Set[String]] = Map(
+    // https://github.com/apache/arrow-rs/blob/main/parquet/src/encryption/ciphers.rs#L21
+    KeyToolkit.DATA_KEY_LENGTH_PROPERTY_NAME -> Set(KeyToolkit.DATA_KEY_LENGTH_DEFAULT.toString),
+    KeyToolkit.KEK_LENGTH_PROPERTY_NAME -> Set(KeyToolkit.KEK_LENGTH_DEFAULT.toString),
+    // https://github.com/apache/arrow-rs/blob/main/parquet/src/file/metadata/parser.rs#L494
+    PropertiesDrivenCryptoFactory.ENCRYPTION_ALGORITHM_PROPERTY_NAME -> Set("AES_GCM_V1"))
+
def writeFieldId(conf: SQLConf): Boolean =
conf.getConfString(PARQUET_FIELD_ID_WRITE_ENABLED, "false").toBoolean
@@ -38,4 +50,36 @@ object CometParquetUtils {
def ignoreMissingIds(conf: SQLConf): Boolean =
conf.getConfString(IGNORE_MISSING_PARQUET_FIELD_ID, "false").toBoolean
+
+  /**
+   * Checks if the given Hadoop configuration contains any unsupported encryption settings.
+   *
+   * @param hadoopConf
+   *   The Hadoop configuration to check
+   * @return
+   *   true if all encryption configurations are supported, false if any unsupported config is
+   *   found
+   */
+  def isEncryptionConfigSupported(hadoopConf: Configuration): Boolean = {
+    // Check configurations that, if present, can only have specific allowed values
+    val supportedListCheck = SUPPORTED_ENCRYPTION_CONFIGS.forall {
+      case (configKey, supportedValues) =>
+        val configValue = Option(hadoopConf.get(configKey))
+        configValue match {
+          case Some(value) => supportedValues.contains(value)
+          case None => true // Config not set, so it's supported
+        }
+    }
+
+    supportedListCheck
+  }
+
+  def encryptionEnabled(hadoopConf: Configuration): Boolean = {
+    // TODO: Are there any other properties to check?
+    val encryptionKeys = Seq(
+      DecryptionPropertiesFactory.CRYPTO_FACTORY_CLASS_PROPERTY_NAME,
+      KeyToolkit.KMS_CLIENT_CLASS_PROPERTY_NAME)
+
+    encryptionKeys.exists(key => Option(hadoopConf.get(key)).exists(_.nonEmpty))
+  }
}
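To make the gating concrete, here is a minimal Scala sketch of how these two helpers respond to a configuration that enables Parquet modular encryption (the two property-name constants resolve to "parquet.crypto.factory.class" and "parquet.encryption.kms.client.class"; the KMS value below is the mock used elsewhere in this commit):

    import org.apache.hadoop.conf.Configuration
    import org.apache.parquet.crypto.DecryptionPropertiesFactory
    import org.apache.parquet.crypto.keytools.KeyToolkit
    import org.apache.comet.parquet.CometParquetUtils

    val conf = new Configuration()
    // Neither a crypto factory nor a KMS client is configured: encryption is off
    assert(!CometParquetUtils.encryptionEnabled(conf))

    conf.set(
      DecryptionPropertiesFactory.CRYPTO_FACTORY_CLASS_PROPERTY_NAME,
      "org.apache.parquet.crypto.keytools.PropertiesDrivenCryptoFactory")
    conf.set(
      KeyToolkit.KMS_CLIENT_CLASS_PROPERTY_NAME,
      "org.apache.parquet.crypto.keytools.mocks.InMemoryKMS")
    assert(CometParquetUtils.encryptionEnabled(conf))

    // Nothing overrides the checked defaults (128-bit keys, AES_GCM_V1), so this passes;
    // explicitly configuring, say, AES_GCM_CTR_V1 would make it return false.
    assert(CometParquetUtils.isEncryptionConfigSupported(conf))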
diff --git a/native/Cargo.lock b/native/Cargo.lock
index 483d2e070..ad8c24d9d 100644
--- a/native/Cargo.lock
+++ b/native/Cargo.lock
@@ -1424,6 +1424,7 @@ dependencies = [
"datafusion-session",
"datafusion-sql",
"futures",
+ "hex",
"itertools 0.14.0",
"log",
"object_store",
@@ -1611,6 +1612,7 @@ dependencies = [
"chrono",
"half",
"hashbrown 0.14.5",
+ "hex",
"indexmap",
"libc",
"log",
@@ -1738,6 +1740,7 @@ dependencies = [
"datafusion-pruning",
"datafusion-session",
"futures",
+ "hex",
"itertools 0.14.0",
"log",
"object_store",
@@ -1768,6 +1771,7 @@ dependencies = [
"log",
"object_store",
"parking_lot",
+ "parquet",
"rand",
"tempfile",
"url",
diff --git a/native/core/Cargo.toml b/native/core/Cargo.toml
index aa4425c96..e6e4c6f3c 100644
--- a/native/core/Cargo.toml
+++ b/native/core/Cargo.toml
@@ -59,7 +59,7 @@ bytes = { workspace = true }
tempfile = "3.8.0"
itertools = "0.14.0"
paste = "1.0.14"
-datafusion = { workspace = true }
+datafusion = { workspace = true, features = ["parquet_encryption"] }
datafusion-spark = { workspace = true }
once_cell = "1.18.0"
regex = { workspace = true }
diff --git a/native/core/src/errors.rs b/native/core/src/errors.rs
index b3241477b..ecac7af94 100644
--- a/native/core/src/errors.rs
+++ b/native/core/src/errors.rs
@@ -185,6 +185,15 @@ impl From<CometError> for DataFusionError {
}
}
+impl From<CometError> for ParquetError {
+ fn from(value: CometError) -> Self {
+ match value {
+ CometError::Parquet { source } => source,
+ _ => ParquetError::General(value.to_string()),
+ }
+ }
+}
+
impl From<CometError> for ExecutionError {
fn from(value: CometError) -> Self {
match value {
diff --git a/native/core/src/execution/jni_api.rs b/native/core/src/execution/jni_api.rs
index 83dbd68e7..b76108ad9 100644
--- a/native/core/src/execution/jni_api.rs
+++ b/native/core/src/execution/jni_api.rs
@@ -78,6 +78,7 @@ use crate::execution::spark_plan::SparkPlan;
 use crate::execution::tracing::{log_memory_usage, trace_begin, trace_end, with_trace};
+use crate::parquet::encryption_support::{CometEncryptionFactory, ENCRYPTION_FACTORY_ID};
use datafusion_comet_proto::spark_operator::operator::OpStruct;
use log::info;
use once_cell::sync::Lazy;
@@ -171,6 +172,7 @@ pub unsafe extern "system" fn Java_org_apache_comet_Native_createPlan(
explain_native: jboolean,
tracing_enabled: jboolean,
max_temp_directory_size: jlong,
+ key_unwrapper_obj: JObject,
) -> jlong {
try_unwrap_or_throw(&e, |mut env| {
with_trace("createPlan", tracing_enabled != JNI_FALSE, || {
@@ -247,6 +249,17 @@ pub unsafe extern "system" fn Java_org_apache_comet_Native_createPlan(
None
};
+ // Handle key unwrapper for encrypted files
+ if !key_unwrapper_obj.is_null() {
+ let encryption_factory = CometEncryptionFactory {
+                    key_unwrapper: jni_new_global_ref!(env, key_unwrapper_obj)?,
+ };
+                session.runtime_env().register_parquet_encryption_factory(
+                    ENCRYPTION_FACTORY_ID,
+ Arc::new(encryption_factory),
+ );
+ }
+
let exec_context = Box::new(ExecutionContext {
id,
task_attempt_id,
diff --git a/native/core/src/execution/planner.rs b/native/core/src/execution/planner.rs
index 517c037e9..329edc5d2 100644
--- a/native/core/src/execution/planner.rs
+++ b/native/core/src/execution/planner.rs
@@ -1358,6 +1358,8 @@ impl PhysicalPlanner {
default_values,
scan.session_timezone.as_str(),
scan.case_sensitive,
+ self.session_ctx(),
+ scan.encryption_enabled,
)?;
Ok((
vec![],
diff --git a/native/core/src/parquet/encryption_support.rs b/native/core/src/parquet/encryption_support.rs
new file mode 100644
index 000000000..ff67a3fcb
--- /dev/null
+++ b/native/core/src/parquet/encryption_support.rs
@@ -0,0 +1,172 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+use crate::execution::operators::ExecutionError;
+use crate::jvm_bridge::{check_exception, JVMClasses};
+use arrow::datatypes::SchemaRef;
+use async_trait::async_trait;
+use datafusion::common::extensions_options;
+use datafusion::config::EncryptionFactoryOptions;
+use datafusion::error::DataFusionError;
+use datafusion::execution::parquet_encryption::EncryptionFactory;
+use jni::objects::{GlobalRef, JMethodID};
+use object_store::path::Path;
+use parquet::encryption::decrypt::{FileDecryptionProperties, KeyRetriever};
+use parquet::encryption::encrypt::FileEncryptionProperties;
+use parquet::errors::ParquetError;
+use std::sync::Arc;
+
+pub const ENCRYPTION_FACTORY_ID: &str = "comet.jni_kms_encryption";
+
+extensions_options! {
+ pub struct CometEncryptionConfig {
+        // Native side strips file down to a path (not a URI) but Spark wants the full URI,
+ // so we cache the prefix to stick on the front before calling over JNI
+ pub uri_base: String, default = "file:///".into()
+ }
+}
+
+#[derive(Debug)]
+pub struct CometEncryptionFactory {
+ pub(crate) key_unwrapper: GlobalRef,
+}
+
+/// `EncryptionFactory` is a DataFusion trait for types that generate
+/// file encryption and decryption properties.
+#[async_trait]
+impl EncryptionFactory for CometEncryptionFactory {
+ async fn get_file_encryption_properties(
+ &self,
+ _options: &EncryptionFactoryOptions,
+ _schema: &SchemaRef,
+ _file_path: &Path,
+ ) -> Result<Option<FileEncryptionProperties>, DataFusionError> {
+ Err(DataFusionError::NotImplemented(
+ "Comet does not support Parquet encryption yet."
+ .parse()
+ .unwrap(),
+ ))
+ }
+
+ /// Generate file decryption properties to use when reading a Parquet file.
+    /// Rather than provide the AES keys directly for decryption, we set a `KeyRetriever`
+ /// that can determine the keys using the encryption metadata.
+ async fn get_file_decryption_properties(
+ &self,
+ options: &EncryptionFactoryOptions,
+ file_path: &Path,
+ ) -> Result<Option<FileDecryptionProperties>, DataFusionError> {
+ let config: CometEncryptionConfig = options.to_extension_options()?;
+
+ let full_path: String = config.uri_base + file_path.as_ref();
+        let key_retriever = CometKeyRetriever::new(&full_path, self.key_unwrapper.clone())
+            .map_err(|e| DataFusionError::External(Box::new(e)))?;
+        let decryption_properties =
+            FileDecryptionProperties::with_key_retriever(Arc::new(key_retriever)).build()?;
+ Ok(Some(decryption_properties))
+ }
+}
+
+pub struct CometKeyRetriever {
+ file_path: String,
+ key_unwrapper: GlobalRef,
+ get_key_method_id: JMethodID,
+}
+
+impl CometKeyRetriever {
+    pub fn new(file_path: &str, key_unwrapper: GlobalRef) -> Result<Self, ExecutionError> {
+ let mut env = JVMClasses::get_env()?;
+
+ Ok(CometKeyRetriever {
+ file_path: file_path.to_string(),
+ key_unwrapper,
+ get_key_method_id: env
+ .get_method_id(
+ "org/apache/comet/parquet/CometFileKeyUnwrapper",
+ "getKey",
+ "(Ljava/lang/String;[B)[B",
+ )
+ .map_err(|e| {
+                    ExecutionError::GeneralError(format!("Failed to get JNI method ID: {}", e))
+ })?,
+ })
+ }
+}
+
+impl KeyRetriever for CometKeyRetriever {
+    /// Get a data encryption key using the metadata stored in the Parquet file.
+    fn retrieve_key(&self, key_metadata: &[u8]) -> datafusion::parquet::errors::Result<Vec<u8>> {
+ use jni::{objects::JObject, signature::ReturnType};
+
+ // Get JNI environment
+ let mut env = JVMClasses::get_env()?;
+
+ // Get the key unwrapper instance from GlobalRef
+ let unwrapper_instance = self.key_unwrapper.as_obj();
+
+        let instance: JObject = unsafe { JObject::from_raw(unwrapper_instance.as_raw()) };
+
+ // Convert file path to JString
+ let file_path_jstring = env
+ .new_string(&self.file_path)
+            .map_err(|e| ParquetError::General(format!("Failed to create JString: {}", e)))?;
+
+ // Convert key_metadata to JByteArray
+ let key_metadata_array = env
+ .byte_array_from_slice(key_metadata)
+            .map_err(|e| ParquetError::General(format!("Failed to create byte array: {}", e)))?;
+
+        // Call instance method CometFileKeyUnwrapper.getKey(String, byte[]) -> byte[]
+ let result = unsafe {
+ env.call_method_unchecked(
+ instance,
+ self.get_key_method_id,
+ ReturnType::Array,
+ &[
+ jni::objects::JValue::from(&file_path_jstring).as_jni(),
+ jni::objects::JValue::from(&key_metadata_array).as_jni(),
+ ],
+ )
+ };
+
+ // Check for Java exceptions first, before processing the result
+ if let Some(exception) = check_exception(&mut env).map_err(|e| {
+            ParquetError::General(format!("Failed to check for Java exception: {}", e))
+ })? {
+ return Err(ParquetError::General(format!(
+ "Java exception during key retrieval: {}",
+ exception
+ )));
+ }
+
+ let result =
+            result.map_err(|e| ParquetError::General(format!("JNI method call failed: {}", e)))?;
+
+ // Extract the byte array from the result
+ let result_array = result
+ .l()
+            .map_err(|e| ParquetError::General(format!("Failed to extract result: {}", e)))?;
+
+ // Convert JObject to JByteArray and then to Vec<u8>
+ let byte_array: jni::objects::JByteArray = result_array.into();
+
+ let result_vec = env
+ .convert_byte_array(&byte_array)
+            .map_err(|e| ParquetError::General(format!("Failed to convert byte array: {}", e)))?;
+ Ok(result_vec)
+ }
+}
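One detail worth calling out: the method ID above is resolved with the JNI descriptor "(Ljava/lang/String;[B)[B". That descriptor corresponds to the following JVM-side shape (a hypothetical Scala restatement of the contract only; the real implementation is the Java CometFileKeyUnwrapper earlier in this commit):

    // Hypothetical trait restating what the native KeyRetriever binds to
    trait KeyUnwrapperContract {
      // Ljava/lang/String; -> String, [B -> Array[Byte]; returns the unwrapped key bytes
      def getKey(filePath: String, keyMetadata: Array[Byte]): Array[Byte]
    }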
diff --git a/native/core/src/parquet/mod.rs b/native/core/src/parquet/mod.rs
index a6efe4ed5..ca70c2fc3 100644
--- a/native/core/src/parquet/mod.rs
+++ b/native/core/src/parquet/mod.rs
@@ -16,6 +16,7 @@
// under the License.
pub mod data_type;
+pub mod encryption_support;
pub mod mutable_vector;
pub use mutable_vector::*;
@@ -52,7 +53,9 @@ use crate::execution::operators::ExecutionError;
use crate::execution::planner::PhysicalPlanner;
use crate::execution::serde;
use crate::execution::utils::SparkArrowConvert;
+use crate::jvm_bridge::{jni_new_global_ref, JVMClasses};
use crate::parquet::data_type::AsBytes;
+use crate::parquet::encryption_support::{CometEncryptionFactory, ENCRYPTION_FACTORY_ID};
use crate::parquet::parquet_exec::init_datasource_exec;
use crate::parquet::parquet_support::prepare_object_store_with_configs;
use arrow::array::{Array, RecordBatch};
@@ -712,8 +715,10 @@ pub unsafe extern "system" fn Java_org_apache_comet_parquet_Native_initRecordBat
batch_size: jint,
case_sensitive: jboolean,
object_store_options: jobject,
+ key_unwrapper_obj: JObject,
) -> jlong {
try_unwrap_or_throw(&e, |mut env| unsafe {
+ JVMClasses::init(&mut env);
         let session_config = SessionConfig::new().with_batch_size(batch_size as usize);
         let planner =
             PhysicalPlanner::new(Arc::new(SessionContext::new_with_config(session_config)), 0);
@@ -766,6 +771,22 @@ pub unsafe extern "system" fn Java_org_apache_comet_parquet_Native_initRecordBat
.unwrap()
.into();
+ // Handle key unwrapper for encrypted files
+ let encryption_enabled = if !key_unwrapper_obj.is_null() {
+ let encryption_factory = CometEncryptionFactory {
+ key_unwrapper: jni_new_global_ref!(env, key_unwrapper_obj)?,
+ };
+ session_ctx
+ .runtime_env()
+ .register_parquet_encryption_factory(
+ ENCRYPTION_FACTORY_ID,
+ Arc::new(encryption_factory),
+ );
+ true
+ } else {
+ false
+ };
+
let scan = init_datasource_exec(
required_schema,
Some(data_schema),
@@ -778,6 +799,8 @@ pub unsafe extern "system" fn Java_org_apache_comet_parquet_Native_initRecordBat
None,
session_timezone.as_str(),
case_sensitive != JNI_FALSE,
+ session_ctx,
+ encryption_enabled,
)?;
let partition_index: usize = 0;
diff --git a/native/core/src/parquet/parquet_exec.rs b/native/core/src/parquet/parquet_exec.rs
index 4b587b7ba..0a95ec999 100644
--- a/native/core/src/parquet/parquet_exec.rs
+++ b/native/core/src/parquet/parquet_exec.rs
@@ -16,6 +16,7 @@
// under the License.
use crate::execution::operators::ExecutionError;
+use crate::parquet::encryption_support::{CometEncryptionConfig, ENCRYPTION_FACTORY_ID};
use crate::parquet::parquet_support::SparkParquetOptions;
use crate::parquet::schema_adapter::SparkSchemaAdapterFactory;
use arrow::datatypes::{Field, SchemaRef};
@@ -28,6 +29,7 @@ use datafusion::datasource::source::DataSourceExec;
use datafusion::execution::object_store::ObjectStoreUrl;
use datafusion::physical_expr::expressions::BinaryExpr;
use datafusion::physical_expr::PhysicalExpr;
+use datafusion::prelude::SessionContext;
use datafusion::scalar::ScalarValue;
use datafusion_comet_spark_expr::EvalMode;
use itertools::Itertools;
@@ -66,9 +68,16 @@ pub(crate) fn init_datasource_exec(
default_values: Option<HashMap<usize, ScalarValue>>,
session_timezone: &str,
case_sensitive: bool,
+ session_ctx: &Arc<SessionContext>,
+ encryption_enabled: bool,
) -> Result<Arc<DataSourceExec>, ExecutionError> {
- let (table_parquet_options, spark_parquet_options) =
- get_options(session_timezone, case_sensitive);
+ let (table_parquet_options, spark_parquet_options) = get_options(
+ session_timezone,
+ case_sensitive,
+ &object_store_url,
+ encryption_enabled,
+ );
+
let mut parquet_source = ParquetSource::new(table_parquet_options);
// Create a conjunctive form of the vector because ParquetExecBuilder takes
@@ -87,6 +96,14 @@ pub(crate) fn init_datasource_exec(
}
}
+ if encryption_enabled {
+ parquet_source = parquet_source.with_encryption_factory(
+ session_ctx
+ .runtime_env()
+ .parquet_encryption_factory(ENCRYPTION_FACTORY_ID)?,
+ );
+ }
+
let file_source = parquet_source.with_schema_adapter_factory(Arc::new(
SparkSchemaAdapterFactory::new(spark_parquet_options, default_values),
))?;
@@ -125,6 +142,8 @@ pub(crate) fn init_datasource_exec(
fn get_options(
session_timezone: &str,
case_sensitive: bool,
+ object_store_url: &ObjectStoreUrl,
+ encryption_enabled: bool,
) -> (TableParquetOptions, SparkParquetOptions) {
let mut table_parquet_options = TableParquetOptions::new();
table_parquet_options.global.pushdown_filters = true;
@@ -134,6 +153,16 @@ fn get_options(
SparkParquetOptions::new(EvalMode::Legacy, session_timezone, false);
spark_parquet_options.allow_cast_unsigned_ints = true;
spark_parquet_options.case_sensitive = case_sensitive;
+
+ if encryption_enabled {
+ table_parquet_options.crypto.configure_factory(
+ ENCRYPTION_FACTORY_ID,
+ &CometEncryptionConfig {
+ uri_base: object_store_url.to_string(),
+ },
+ );
+ }
+
(table_parquet_options, spark_parquet_options)
}
diff --git a/native/proto/src/proto/operator.proto b/native/proto/src/proto/operator.proto
index 57e012b36..a243ab6b0 100644
--- a/native/proto/src/proto/operator.proto
+++ b/native/proto/src/proto/operator.proto
@@ -104,6 +104,7 @@ message NativeScan {
  // configuration value "spark.hadoop.fs.s3a.access.key" will be stored as "fs.s3a.access.key" in
  // the map.
map<string, string> object_store_options = 13;
+ bool encryption_enabled = 14;
}
message Projection {
diff --git a/spark/src/main/scala/org/apache/comet/CometExecIterator.scala b/spark/src/main/scala/org/apache/comet/CometExecIterator.scala
index a4e9494b6..8603a7b9a 100644
--- a/spark/src/main/scala/org/apache/comet/CometExecIterator.scala
+++ b/spark/src/main/scala/org/apache/comet/CometExecIterator.scala
@@ -24,14 +24,18 @@ import java.lang.management.ManagementFactory
import scala.util.matching.Regex
+import org.apache.hadoop.conf.Configuration
import org.apache.spark._
+import org.apache.spark.broadcast.Broadcast
import org.apache.spark.internal.Logging
import org.apache.spark.network.util.ByteUnit
import org.apache.spark.sql.comet.CometMetricNode
import org.apache.spark.sql.vectorized._
+import org.apache.spark.util.SerializableConfiguration
-import org.apache.comet.CometConf.{COMET_BATCH_SIZE, COMET_DEBUG_ENABLED, COMET_EXEC_MEMORY_POOL_TYPE, COMET_EXPLAIN_NATIVE_ENABLED, COMET_METRICS_UPDATE_INTERVAL}
+import org.apache.comet.CometConf._
import org.apache.comet.Tracing.withTrace
+import org.apache.comet.parquet.CometFileKeyUnwrapper
import org.apache.comet.serde.Config.ConfigMap
import org.apache.comet.vector.NativeUtil
@@ -52,6 +56,8 @@ import org.apache.comet.vector.NativeUtil
* The number of partitions.
* @param partitionIndex
* The index of the partition.
+ * @param encryptedFilePaths
+ * Paths to encrypted Parquet files that need key unwrapping.
*/
class CometExecIterator(
val id: Long,
@@ -60,7 +66,9 @@ class CometExecIterator(
protobufQueryPlan: Array[Byte],
nativeMetrics: CometMetricNode,
numParts: Int,
- partitionIndex: Int)
+ partitionIndex: Int,
+    broadcastedHadoopConfForEncryption: Option[Broadcast[SerializableConfiguration]] = None,
+ encryptedFilePaths: Seq[String] = Seq.empty)
extends Iterator[ColumnarBatch]
with Logging {
@@ -73,6 +81,7 @@ class CometExecIterator(
private val cometBatchIterators = inputs.map { iterator =>
new CometBatchIterator(iterator, nativeUtil)
}.toArray
+
private val plan = {
val conf = SparkEnv.get.conf
val localDiskDirs = SparkEnv.get.blockManager.getLocalDiskDirs
@@ -102,6 +111,19 @@ class CometExecIterator(
getMemoryLimitPerTask(conf)
}
+ // Create keyUnwrapper if encryption is enabled
+ val keyUnwrapper = if (encryptedFilePaths.nonEmpty) {
+ val unwrapper = new CometFileKeyUnwrapper()
+      val hadoopConf: Configuration = broadcastedHadoopConfForEncryption.get.value.value
+
+ encryptedFilePaths.foreach(filePath =>
+ unwrapper.storeDecryptionKeyRetriever(filePath, hadoopConf))
+
+ unwrapper
+ } else {
+ null
+ }
+
nativeLib.createPlan(
id,
cometBatchIterators,
@@ -121,7 +143,8 @@ class CometExecIterator(
debug = COMET_DEBUG_ENABLED.get(),
explain = COMET_EXPLAIN_NATIVE_ENABLED.get(),
tracingEnabled,
- maxTempDirectorySize = CometConf.COMET_MAX_TEMP_DIRECTORY_SIZE.get())
+ maxTempDirectorySize = CometConf.COMET_MAX_TEMP_DIRECTORY_SIZE.get(),
+ keyUnwrapper)
}
private var nextBatch: Option[ColumnarBatch] = None
@@ -145,6 +168,7 @@ class CometExecIterator(
def convertToInt(threads: String): Int = {
      if (threads == "*") Runtime.getRuntime.availableProcessors() else threads.toInt
}
+
val LOCAL_N_REGEX = """local\[([0-9]+|\*)\]""".r
val LOCAL_N_FAILURES_REGEX = """local\[([0-9]+|\*)\s*,\s*([0-9]+)\]""".r
val master = conf.get("spark.master")
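Because a Hadoop Configuration is not Java-serializable, the conf travels to executors as a broadcast of SerializableConfiguration. A minimal Scala sketch of that wiring (spark and encryptedFilePaths are assumed to be in scope; the real plumbing lives in CometNativeExec, shown later in this diff):

    import org.apache.spark.util.SerializableConfiguration
    import org.apache.comet.parquet.CometFileKeyUnwrapper

    // Driver side: wrap and broadcast the conf
    val hadoopConf = spark.sessionState.newHadoopConf()
    val broadcastedConf =
      spark.sparkContext.broadcast(new SerializableConfiguration(hadoopConf))

    // Executor side: unwrap the conf and prime one retriever per encrypted file
    val conf = broadcastedConf.value.value
    val unwrapper = new CometFileKeyUnwrapper()
    encryptedFilePaths.foreach(p => unwrapper.storeDecryptionKeyRetriever(p, conf))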
diff --git a/spark/src/main/scala/org/apache/comet/Native.scala b/spark/src/main/scala/org/apache/comet/Native.scala
index 13edf2997..fb24dce0d 100644
--- a/spark/src/main/scala/org/apache/comet/Native.scala
+++ b/spark/src/main/scala/org/apache/comet/Native.scala
@@ -24,6 +24,8 @@ import java.nio.ByteBuffer
import org.apache.spark.CometTaskMemoryManager
import org.apache.spark.sql.comet.CometMetricNode
+import org.apache.comet.parquet.CometFileKeyUnwrapper
+
class Native extends NativeBase {
// scalastyle:off
@@ -69,7 +71,8 @@ class Native extends NativeBase {
debug: Boolean,
explain: Boolean,
tracingEnabled: Boolean,
- maxTempDirectorySize: Long): Long
+ maxTempDirectorySize: Long,
+ keyUnwrapper: CometFileKeyUnwrapper): Long
// scalastyle:on
/**
diff --git a/spark/src/main/scala/org/apache/comet/rules/CometScanRule.scala b/spark/src/main/scala/org/apache/comet/rules/CometScanRule.scala
index cbca7304d..950d0e9d3 100644
--- a/spark/src/main/scala/org/apache/comet/rules/CometScanRule.scala
+++ b/spark/src/main/scala/org/apache/comet/rules/CometScanRule.scala
@@ -46,12 +46,14 @@ import org.apache.comet.CometSparkSessionExtensions.{isCometLoaded, isCometScanE
import org.apache.comet.DataTypeSupport.isComplexType
import org.apache.comet.objectstore.NativeConfig
import org.apache.comet.parquet.{CometParquetScan, Native, SupportsComet}
+import org.apache.comet.parquet.CometParquetUtils.{encryptionEnabled, isEncryptionConfigSupported}
import org.apache.comet.shims.CometTypeShim
/**
* Spark physical optimizer rule for replacing Spark scans with Comet scans.
*/
case class CometScanRule(session: SparkSession) extends Rule[SparkPlan] with CometTypeShim {
+
import CometScanRule._
  private lazy val showTransformations = CometConf.COMET_EXPLAIN_TRANSFORMATIONS.get()
@@ -144,22 +146,14 @@ case class CometScanRule(session: SparkSession) extends Rule[SparkPlan] with Com
return withInfos(scanExec, fallbackReasons.toSet)
}
- val encryptionEnabled: Boolean =
- conf.getConfString("parquet.crypto.factory.class", "").nonEmpty &&
-        conf.getConfString("parquet.encryption.kms.client.class", "").nonEmpty
-
var scanImpl = COMET_NATIVE_SCAN_IMPL.get()
+ val hadoopConf = scanExec.relation.sparkSession.sessionState
+ .newHadoopConfWithOptions(scanExec.relation.options)
+
// if scan is auto then pick the best available scan
if (scanImpl == SCAN_AUTO) {
- if (encryptionEnabled) {
- logInfo(
- s"Auto scan mode falling back to $SCAN_NATIVE_COMET because " +
-            s"$SCAN_NATIVE_ICEBERG_COMPAT does not support reading encrypted Parquet files")
- scanImpl = SCAN_NATIVE_COMET
- } else {
- scanImpl = selectScan(scanExec, r.partitionSchema)
- }
+ scanImpl = selectScan(scanExec, r.partitionSchema, hadoopConf)
}
if (scanImpl == SCAN_NATIVE_DATAFUSION && !COMET_EXEC_ENABLED.get()) {
@@ -206,10 +200,10 @@ case class CometScanRule(session: SparkSession) extends Rule[SparkPlan] with Com
return withInfos(scanExec, fallbackReasons.toSet)
}
- if (scanImpl != CometConf.SCAN_NATIVE_COMET && encryptionEnabled) {
- fallbackReasons +=
- "Full native scan disabled because encryption is not supported"
- return withInfos(scanExec, fallbackReasons.toSet)
+    if (scanImpl != CometConf.SCAN_NATIVE_COMET && encryptionEnabled(hadoopConf)) {
+ if (!isEncryptionConfigSupported(hadoopConf)) {
+ return withInfos(scanExec, fallbackReasons.toSet)
+ }
}
val typeChecker = CometScanTypeChecker(scanImpl)
@@ -303,7 +297,10 @@ case class CometScanRule(session: SparkSession) extends Rule[SparkPlan] with Com
}
}
-  private def selectScan(scanExec: FileSourceScanExec, partitionSchema: StructType): String = {
+ private def selectScan(
+ scanExec: FileSourceScanExec,
+ partitionSchema: StructType,
+ hadoopConf: Configuration): String = {
val fallbackReasons = new ListBuffer[String]()
@@ -313,10 +310,7 @@ case class CometScanRule(session: SparkSession) extends Rule[SparkPlan] with Com
val filePath = scanExec.relation.inputFiles.headOption
if (filePath.exists(_.startsWith("s3a://"))) {
- validateObjectStoreConfig(
- filePath.get,
- session.sparkContext.hadoopConfiguration,
- fallbackReasons)
+ validateObjectStoreConfig(filePath.get, hadoopConf, fallbackReasons)
}
} else {
      fallbackReasons += s"$SCAN_NATIVE_ICEBERG_COMPAT only supports local filesystem and S3"
diff --git a/spark/src/main/scala/org/apache/comet/serde/QueryPlanSerde.scala b/spark/src/main/scala/org/apache/comet/serde/QueryPlanSerde.scala
index 8fc7c2d63..43f8b7293 100644
--- a/spark/src/main/scala/org/apache/comet/serde/QueryPlanSerde.scala
+++ b/spark/src/main/scala/org/apache/comet/serde/QueryPlanSerde.scala
@@ -48,6 +48,7 @@ import org.apache.comet.{CometConf, ConfigEntry}
import org.apache.comet.CometSparkSessionExtensions.{isCometScan, withInfo}
import org.apache.comet.expressions._
import org.apache.comet.objectstore.NativeConfig
+import org.apache.comet.parquet.CometParquetUtils
import org.apache.comet.serde.ExprOuterClass.{AggExpr, Expr, ScalarFunc}
 import org.apache.comet.serde.OperatorOuterClass.{AggregateMode => CometAggregateMode, BuildSide, JoinType, Operator}
 import org.apache.comet.serde.QueryPlanSerde.{exprToProtoInternal, optExprWithInfo, scalarFunctionExprToProto}
@@ -1161,6 +1162,9 @@ object QueryPlanSerde extends Logging with CometExprShim {
// Collect S3/cloud storage configurations
val hadoopConf = scan.relation.sparkSession.sessionState
.newHadoopConfWithOptions(scan.relation.options)
+
+    nativeScanBuilder.setEncryptionEnabled(CometParquetUtils.encryptionEnabled(hadoopConf))
+
firstPartition.foreach { partitionFile =>
        val objectStoreOptions =
          NativeConfig.extractObjectStoreOptions(hadoopConf, partitionFile.pathUri)
@@ -1702,6 +1706,7 @@ object QueryPlanSerde extends Logging with CometExprShim {
}
// scalastyle:off
+
/**
* Align w/ Arrow's
   * [[https://github.com/apache/arrow-rs/blob/55.2.0/arrow-ord/src/rank.rs#L30-L40 can_rank]] and
diff --git a/spark/src/main/scala/org/apache/spark/sql/comet/execution/shuffle/CometNativeShuffleWriter.scala b/spark/src/main/scala/org/apache/spark/sql/comet/execution/shuffle/CometNativeShuffleWriter.scala
index 3dfd1f8d0..43a1e5b9a 100644
--- a/spark/src/main/scala/org/apache/spark/sql/comet/execution/shuffle/CometNativeShuffleWriter.scala
+++ b/spark/src/main/scala/org/apache/spark/sql/comet/execution/shuffle/CometNativeShuffleWriter.scala
@@ -103,7 +103,9 @@ class CometNativeShuffleWriter[K, V](
nativePlan,
nativeMetrics,
numParts,
- context.partitionId())
+ context.partitionId(),
+ broadcastedHadoopConfForEncryption = None,
+ encryptedFilePaths = Seq.empty)
while (cometIter.hasNext) {
cometIter.next()
diff --git a/spark/src/main/scala/org/apache/spark/sql/comet/operators.scala b/spark/src/main/scala/org/apache/spark/sql/comet/operators.scala
index a7cfacc47..de6892638 100644
--- a/spark/src/main/scala/org/apache/spark/sql/comet/operators.scala
+++ b/spark/src/main/scala/org/apache/spark/sql/comet/operators.scala
@@ -25,27 +25,29 @@ import scala.collection.mutable
import scala.collection.mutable.ArrayBuffer
import org.apache.spark.TaskContext
+import org.apache.spark.broadcast.Broadcast
import org.apache.spark.rdd.RDD
import org.apache.spark.sql.catalyst.InternalRow
 import org.apache.spark.sql.catalyst.expressions.{Attribute, AttributeSet, Expression, NamedExpression, SortOrder}
 import org.apache.spark.sql.catalyst.expressions.aggregate.{AggregateExpression, AggregateMode}
 import org.apache.spark.sql.catalyst.optimizer.{BuildLeft, BuildRight, BuildSide}
 import org.apache.spark.sql.catalyst.plans._
-import org.apache.spark.sql.catalyst.plans.physical.{HashPartitioning, HashPartitioningLike, Partitioning, PartitioningCollection, UnknownPartitioning}
+import org.apache.spark.sql.catalyst.plans.physical._
import org.apache.spark.sql.comet.execution.shuffle.CometShuffleExchangeExec
import org.apache.spark.sql.comet.util.Utils
-import org.apache.spark.sql.execution.{BinaryExecNode, ColumnarToRowExec, ExecSubqueryExpression, ExplainUtils, LeafExecNode, ScalarSubquery, SparkPlan, UnaryExecNode}
-import org.apache.spark.sql.execution.PartitioningPreservingUnaryExecNode
+import org.apache.spark.sql.execution._
 import org.apache.spark.sql.execution.adaptive.{AQEShuffleReadExec, BroadcastQueryStageExec, ShuffleQueryStageExec}
import org.apache.spark.sql.execution.exchange.ReusedExchangeExec
import org.apache.spark.sql.execution.metric.{SQLMetric, SQLMetrics}
import org.apache.spark.sql.internal.SQLConf
import org.apache.spark.sql.vectorized.ColumnarBatch
+import org.apache.spark.util.SerializableConfiguration
import org.apache.spark.util.io.ChunkedByteBuffer
import com.google.common.base.Objects
import org.apache.comet.{CometConf, CometExecIterator, CometRuntimeException}
+import org.apache.comet.parquet.CometParquetUtils
import org.apache.comet.serde.OperatorOuterClass.Operator
/**
@@ -114,7 +116,9 @@ object CometExec {
nativePlan,
CometMetricNode(Map.empty),
numParts,
- partitionIdx)
+ partitionIdx,
+ broadcastedHadoopConfForEncryption = None,
+ encryptedFilePaths = Seq.empty)
}
def getCometIterator(
@@ -123,7 +127,9 @@ object CometExec {
nativePlan: Operator,
nativeMetrics: CometMetricNode,
numParts: Int,
- partitionIdx: Int): CometExecIterator = {
+ partitionIdx: Int,
+      broadcastedHadoopConfForEncryption: Option[Broadcast[SerializableConfiguration]],
+ encryptedFilePaths: Seq[String]): CometExecIterator = {
val outputStream = new ByteArrayOutputStream()
nativePlan.writeTo(outputStream)
outputStream.close()
@@ -135,7 +141,9 @@ object CometExec {
bytes,
nativeMetrics,
numParts,
- partitionIdx)
+ partitionIdx,
+ broadcastedHadoopConfForEncryption,
+ encryptedFilePaths)
}
/**
@@ -201,6 +209,39 @@ abstract class CometNativeExec extends CometExec {
// TODO: support native metrics for all operators.
val nativeMetrics = CometMetricNode.fromCometPlan(this)
+    // For each relation in a CometNativeScan, generate a hadoopConf and associate
+    // every file path in the relation with it.
+    val cometNativeScans: Seq[CometNativeScanExec] = this
+      .collectLeaves()
+      .filter(_.isInstanceOf[CometNativeScanExec])
+      .map(_.asInstanceOf[CometNativeScanExec])
+    assert(
+      cometNativeScans.size <= 1,
+      "We expect one native scan in a Comet plan since we will broadcast one hadoopConf.")
+    // If this assumption changes in the future, you can look at the commit history of #2447
+    // to see how there used to be a map of relations to broadcasted confs in case multiple
+    // relations appear in a single plan. The example that came up was UNION. See discussion at:
+    // https://github.com/apache/datafusion-comet/pull/2447#discussion_r2406118264
+    val (broadcastedHadoopConfForEncryption, encryptedFilePaths) =
+      cometNativeScans.headOption.fold(
+        (None: Option[Broadcast[SerializableConfiguration]], Seq.empty[String])) { scan =>
+        // This creates a hadoopConf that brings in any SQLConf "spark.hadoop.*" configs and
+        // per-relation configs since different tables might have different decryption
+        // properties.
+        val hadoopConf = scan.relation.sparkSession.sessionState
+          .newHadoopConfWithOptions(scan.relation.options)
+        val encryptionEnabled = CometParquetUtils.encryptionEnabled(hadoopConf)
+        if (encryptionEnabled) {
+          // hadoopConf isn't serializable, so we have to do a broadcasted config.
+          val broadcastedConf =
+            scan.relation.sparkSession.sparkContext
+              .broadcast(new SerializableConfiguration(hadoopConf))
+          (Some(broadcastedConf), scan.relation.inputFiles.toSeq)
+        } else {
+          (None, Seq.empty)
+        }
+      }
+ }
+ }
+
def createCometExecIter(
inputs: Seq[Iterator[ColumnarBatch]],
numParts: Int,
@@ -212,7 +253,9 @@ abstract class CometNativeExec extends CometExec {
serializedPlanCopy,
nativeMetrics,
numParts,
- partitionIndex)
+ partitionIndex,
+ broadcastedHadoopConfForEncryption,
+ encryptedFilePaths)
setSubqueries(it.id, this)
@@ -429,6 +472,7 @@ abstract class CometBinaryExec extends CometNativeExec with BinaryExecNode
*/
case class SerializedPlan(plan: Option[Array[Byte]]) {
def isDefined: Boolean = plan.isDefined
+
def isEmpty: Boolean = plan.isEmpty
}
@@ -442,6 +486,7 @@ case class CometProjectExec(
extends CometUnaryExec
with PartitioningPreservingUnaryExecNode {
override def producedAttributes: AttributeSet = outputSet
+
override protected def withNewChildInternal(newChild: SparkPlan): SparkPlan =
this.copy(child = newChild)
@@ -474,6 +519,7 @@ case class CometFilterExec(
extends CometUnaryExec {
override def outputPartitioning: Partitioning = child.outputPartitioning
+
override def outputOrdering: Seq[SortOrder] = child.outputOrdering
override protected def withNewChildInternal(newChild: SparkPlan): SparkPlan =
@@ -551,7 +597,9 @@ case class CometLocalLimitExec(
extends CometUnaryExec {
override def output: Seq[Attribute] = child.output
+
override def outputPartitioning: Partitioning = child.outputPartitioning
+
override def outputOrdering: Seq[SortOrder] = child.outputOrdering
override protected def withNewChildInternal(newChild: SparkPlan): SparkPlan =
@@ -585,7 +633,9 @@ case class CometGlobalLimitExec(
extends CometUnaryExec {
override def output: Seq[Attribute] = child.output
+
override def outputPartitioning: Partitioning = child.outputPartitioning
+
override def outputOrdering: Seq[SortOrder] = child.outputOrdering
override protected def withNewChildInternal(newChild: SparkPlan): SparkPlan =
@@ -985,6 +1035,7 @@ case class CometScanWrapper(override val nativeOp: Operator, override val origin
extends CometNativeExec
with LeafExecNode {
override val serializedPlanOpt: SerializedPlan = SerializedPlan(None)
+
   override def stringArgs: Iterator[Any] = Iterator(originalPlan.output, originalPlan)
}
diff --git a/spark/src/test/scala/org/apache/spark/sql/benchmark/CometBenchmarkBase.scala b/spark/src/test/scala/org/apache/spark/sql/benchmark/CometBenchmarkBase.scala
index 6e6c62491..1cbe27be9 100644
--- a/spark/src/test/scala/org/apache/spark/sql/benchmark/CometBenchmarkBase.scala
+++ b/spark/src/test/scala/org/apache/spark/sql/benchmark/CometBenchmarkBase.scala
@@ -20,9 +20,14 @@
package org.apache.spark.sql.benchmark
import java.io.File
+import java.nio.charset.StandardCharsets
+import java.util.Base64
import scala.util.Random
+import org.apache.parquet.crypto.DecryptionPropertiesFactory
+import org.apache.parquet.crypto.keytools.{KeyToolkit, PropertiesDrivenCryptoFactory}
+import org.apache.parquet.crypto.keytools.mocks.InMemoryKMS
import org.apache.spark.SparkConf
import org.apache.spark.benchmark.Benchmark
import org.apache.spark.sql.{DataFrame, DataFrameWriter, Row, SparkSession}
@@ -120,6 +125,41 @@ trait CometBenchmarkBase extends SqlBasedBenchmark {
spark.read.parquet(dir).createOrReplaceTempView("parquetV1Table")
}
+ protected def prepareEncryptedTable(
+ dir: File,
+ df: DataFrame,
+ partition: Option[String] = None): Unit = {
+ val testDf = if (partition.isDefined) {
+ df.write.partitionBy(partition.get)
+ } else {
+ df.write
+ }
+
+ saveAsEncryptedParquetV1Table(testDf, dir.getCanonicalPath + "/parquetV1")
+ }
+
+  protected def saveAsEncryptedParquetV1Table(df: DataFrameWriter[Row], dir: String): Unit = {
+    val encoder = Base64.getEncoder
+    val footerKey =
+      encoder.encodeToString("0123456789012345".getBytes(StandardCharsets.UTF_8))
+    val key1 = encoder.encodeToString("1234567890123450".getBytes(StandardCharsets.UTF_8))
+    val cryptoFactoryClass =
+      "org.apache.parquet.crypto.keytools.PropertiesDrivenCryptoFactory"
+    withSQLConf(
+      DecryptionPropertiesFactory.CRYPTO_FACTORY_CLASS_PROPERTY_NAME -> cryptoFactoryClass,
+      KeyToolkit.KMS_CLIENT_CLASS_PROPERTY_NAME ->
+        "org.apache.parquet.crypto.keytools.mocks.InMemoryKMS",
+      InMemoryKMS.KEY_LIST_PROPERTY_NAME ->
+        s"footerKey: ${footerKey}, key1: ${key1}") {
+      df.mode("overwrite")
+        .option("compression", "snappy")
+        .option(PropertiesDrivenCryptoFactory.COLUMN_KEYS_PROPERTY_NAME, "key1: id")
+        .option(PropertiesDrivenCryptoFactory.FOOTER_KEY_PROPERTY_NAME, "footerKey")
+ .parquet(dir)
+ spark.read.parquet(dir).createOrReplaceTempView("parquetV1Table")
+ }
+ }
+
protected def makeDecimalDataFrame(
values: Int,
decimal: DecimalType,
diff --git a/spark/src/test/scala/org/apache/spark/sql/benchmark/CometReadBenchmark.scala b/spark/src/test/scala/org/apache/spark/sql/benchmark/CometReadBenchmark.scala
index 02b9ca5dc..a5db4f290 100644
--- a/spark/src/test/scala/org/apache/spark/sql/benchmark/CometReadBenchmark.scala
+++ b/spark/src/test/scala/org/apache/spark/sql/benchmark/CometReadBenchmark.scala
@@ -20,11 +20,16 @@
package org.apache.spark.sql.benchmark
import java.io.File
+import java.nio.charset.StandardCharsets
+import java.util.Base64
import scala.jdk.CollectionConverters._
import scala.util.Random
import org.apache.hadoop.fs.Path
+import org.apache.parquet.crypto.DecryptionPropertiesFactory
+import org.apache.parquet.crypto.keytools.KeyToolkit
+import org.apache.parquet.crypto.keytools.mocks.InMemoryKMS
import org.apache.spark.TestUtils
import org.apache.spark.benchmark.Benchmark
import org.apache.spark.sql.{DataFrame, SparkSession}
@@ -93,6 +98,94 @@ class CometReadBaseBenchmark extends CometBenchmarkBase {
}
}
+ def encryptedScanBenchmark(values: Int, dataType: DataType): Unit = {
+ // Benchmarks running through spark sql.
+ val sqlBenchmark =
+      new Benchmark(s"SQL Single ${dataType.sql} Encrypted Column Scan", values, output = output)
+
+ val encoder = Base64.getEncoder
+ val footerKey =
+      encoder.encodeToString("0123456789012345".getBytes(StandardCharsets.UTF_8))
+    val key1 = encoder.encodeToString("1234567890123450".getBytes(StandardCharsets.UTF_8))
+ val cryptoFactoryClass =
+ "org.apache.parquet.crypto.keytools.PropertiesDrivenCryptoFactory"
+
+ withTempPath { dir =>
+ withTempTable("parquetV1Table") {
+ prepareEncryptedTable(
+ dir,
+ spark.sql(s"SELECT CAST(value as ${dataType.sql}) id FROM $tbl"))
+
+ val query = dataType match {
+ case BooleanType => "sum(cast(id as bigint))"
+ case _ => "sum(id)"
+ }
+
+ sqlBenchmark.addCase("SQL Parquet - Spark") { _ =>
+ withSQLConf(
+ "spark.memory.offHeap.enabled" -> "true",
+ "spark.memory.offHeap.size" -> "10g",
+            DecryptionPropertiesFactory.CRYPTO_FACTORY_CLASS_PROPERTY_NAME -> cryptoFactoryClass,
+ KeyToolkit.KMS_CLIENT_CLASS_PROPERTY_NAME ->
+ "org.apache.parquet.crypto.keytools.mocks.InMemoryKMS",
+ InMemoryKMS.KEY_LIST_PROPERTY_NAME ->
+ s"footerKey: ${footerKey}, key1: ${key1}") {
+ spark.sql(s"select $query from parquetV1Table").noop()
+ }
+ }
+
+ sqlBenchmark.addCase("SQL Parquet - Comet") { _ =>
+ withSQLConf(
+ "spark.memory.offHeap.enabled" -> "true",
+ "spark.memory.offHeap.size" -> "10g",
+ CometConf.COMET_ENABLED.key -> "true",
+ CometConf.COMET_NATIVE_SCAN_IMPL.key -> SCAN_NATIVE_COMET,
+            DecryptionPropertiesFactory.CRYPTO_FACTORY_CLASS_PROPERTY_NAME -> cryptoFactoryClass,
+ KeyToolkit.KMS_CLIENT_CLASS_PROPERTY_NAME ->
+ "org.apache.parquet.crypto.keytools.mocks.InMemoryKMS",
+ InMemoryKMS.KEY_LIST_PROPERTY_NAME ->
+ s"footerKey: ${footerKey}, key1: ${key1}") {
+ spark.sql(s"select $query from parquetV1Table").noop()
+ }
+ }
+
+ sqlBenchmark.addCase("SQL Parquet - Comet Native DataFusion") { _ =>
+ withSQLConf(
+ "spark.memory.offHeap.enabled" -> "true",
+ "spark.memory.offHeap.size" -> "10g",
+ CometConf.COMET_ENABLED.key -> "true",
+ CometConf.COMET_EXEC_ENABLED.key -> "true",
+ CometConf.COMET_NATIVE_SCAN_IMPL.key -> SCAN_NATIVE_DATAFUSION,
+            DecryptionPropertiesFactory.CRYPTO_FACTORY_CLASS_PROPERTY_NAME -> cryptoFactoryClass,
+ KeyToolkit.KMS_CLIENT_CLASS_PROPERTY_NAME ->
+ "org.apache.parquet.crypto.keytools.mocks.InMemoryKMS",
+ InMemoryKMS.KEY_LIST_PROPERTY_NAME ->
+ s"footerKey: ${footerKey}, key1: ${key1}") {
+ spark.sql(s"select $query from parquetV1Table").noop()
+ }
+ }
+
+        sqlBenchmark.addCase("SQL Parquet - Comet Native Iceberg Compat") { _ =>
+ withSQLConf(
+ "spark.memory.offHeap.enabled" -> "true",
+ "spark.memory.offHeap.size" -> "10g",
+ CometConf.COMET_ENABLED.key -> "true",
+ CometConf.COMET_EXEC_ENABLED.key -> "true",
+ CometConf.COMET_NATIVE_SCAN_IMPL.key -> SCAN_NATIVE_ICEBERG_COMPAT,
+            DecryptionPropertiesFactory.CRYPTO_FACTORY_CLASS_PROPERTY_NAME -> cryptoFactoryClass,
+ KeyToolkit.KMS_CLIENT_CLASS_PROPERTY_NAME ->
+ "org.apache.parquet.crypto.keytools.mocks.InMemoryKMS",
+ InMemoryKMS.KEY_LIST_PROPERTY_NAME ->
+ s"footerKey: ${footerKey}, key1: ${key1}") {
+ spark.sql(s"select $query from parquetV1Table").noop()
+ }
+ }
+
+ sqlBenchmark.run()
+ }
+ }
+ }
+
def decimalScanBenchmark(values: Int, precision: Int, scale: Int): Unit = {
val sqlBenchmark = new Benchmark(
s"SQL Single Decimal(precision: $precision, scale: $scale) Column Scan",
@@ -552,13 +645,20 @@ class CometReadBaseBenchmark extends CometBenchmarkBase {
}
}
-    runBenchmarkWithTable("SQL Single Numeric Column Scan", 1024 * 1024 * 15) { v =>
+    runBenchmarkWithTable("SQL Single Numeric Column Scan", 1024 * 1024 * 128) { v =>
      Seq(BooleanType, ByteType, ShortType, IntegerType, LongType, FloatType, DoubleType)
.foreach { dataType =>
numericScanBenchmark(v, dataType)
}
}
+    runBenchmarkWithTable("SQL Single Numeric Encrypted Column Scan", 1024 * 1024 * 128) { v =>
+      Seq(BooleanType, ByteType, ShortType, IntegerType, LongType, FloatType, DoubleType)
+ .foreach { dataType =>
+ encryptedScanBenchmark(v, dataType)
+ }
+ }
+
runBenchmark("SQL Decimal Column Scan") {
withTempTable(tbl) {
import spark.implicits._
@@ -639,6 +739,7 @@ object CometReadHdfsBenchmark extends CometReadBaseBenchmark with WithHdfsCluste
finally getFileSystem.delete(tempHdfsPath, true)
}
}
+
override protected def prepareTable(
dir: File,
df: DataFrame,
diff --git a/spark/src/test/scala/org/apache/spark/sql/comet/ParquetEncryptionITCase.scala b/spark/src/test/scala/org/apache/spark/sql/comet/ParquetEncryptionITCase.scala
index 8d2c3db72..cff21ecec 100644
--- a/spark/src/test/scala/org/apache/spark/sql/comet/ParquetEncryptionITCase.scala
+++ b/spark/src/test/scala/org/apache/spark/sql/comet/ParquetEncryptionITCase.scala
@@ -19,8 +19,7 @@
package org.apache.spark.sql.comet
-import java.io.File
-import java.io.RandomAccessFile
+import java.io.{File, RandomAccessFile}
import java.nio.charset.StandardCharsets
import java.util.Base64
@@ -29,12 +28,16 @@ import org.scalactic.source.Position
import org.scalatest.Tag
import org.scalatestplus.junit.JUnitRunner
+import org.apache.parquet.crypto.DecryptionPropertiesFactory
+import org.apache.parquet.crypto.keytools.{KeyToolkit, PropertiesDrivenCryptoFactory}
+import org.apache.parquet.crypto.keytools.mocks.InMemoryKMS
import org.apache.spark.{DebugFilesystem, SparkConf}
import org.apache.spark.sql.{CometTestBase, SQLContext}
import org.apache.spark.sql.internal.SQLConf
import org.apache.spark.sql.test.SQLTestUtils
import org.apache.comet.{CometConf, IntegrationTestSuite}
+import org.apache.comet.CometConf.{SCAN_NATIVE_COMET, SCAN_NATIVE_DATAFUSION, SCAN_NATIVE_ICEBERG_COMPAT}
/**
 * An integration test suite that tests parquet modular encryption usage.
@@ -47,90 +50,399 @@ class ParquetEncryptionITCase extends CometTestBase with SQLTestUtils {
encoder.encodeToString("0123456789012345".getBytes(StandardCharsets.UTF_8))
  private val key1 = encoder.encodeToString("1234567890123450".getBytes(StandardCharsets.UTF_8))
  private val key2 = encoder.encodeToString("1234567890123451".getBytes(StandardCharsets.UTF_8))
+ private val cryptoFactoryClass =
+ "org.apache.parquet.crypto.keytools.PropertiesDrivenCryptoFactory"
test("SPARK-34990: Write and read an encrypted parquet") {
-    assume(CometConf.COMET_NATIVE_SCAN_IMPL.get() != CometConf.SCAN_NATIVE_DATAFUSION)
-    assume(CometConf.COMET_NATIVE_SCAN_IMPL.get() != CometConf.SCAN_NATIVE_ICEBERG_COMPAT)
import testImplicits._
-    Seq("org.apache.parquet.crypto.keytools.PropertiesDrivenCryptoFactory").foreach {
- factoryClass =>
- withTempDir { dir =>
- withSQLConf(
- "parquet.crypto.factory.class" -> factoryClass,
- "parquet.encryption.kms.client.class" ->
- "org.apache.parquet.crypto.keytools.mocks.InMemoryKMS",
- "parquet.encryption.key.list" ->
- s"footerKey: ${footerKey}, key1: ${key1}, key2: ${key2}") {
-
- // Make sure encryption works with multiple Parquet files
- val inputDF = spark
- .range(0, 2000)
- .map(i => (i, i.toString, i.toFloat))
- .repartition(10)
- .toDF("a", "b", "c")
- val parquetDir = new File(dir, "parquet").getCanonicalPath
- inputDF.write
- .option("parquet.encryption.column.keys", "key1: a, b; key2: c")
- .option("parquet.encryption.footer.key", "footerKey")
- .parquet(parquetDir)
-
- verifyParquetEncrypted(parquetDir)
-
- val parquetDF = spark.read.parquet(parquetDir)
- assert(parquetDF.inputFiles.nonEmpty)
- val readDataset = parquetDF.select("a", "b", "c")
-
- if (CometConf.COMET_ENABLED.get(conf)) {
- checkSparkAnswerAndOperator(readDataset)
- } else {
- checkAnswer(readDataset, inputDF)
- }
- }
+ withTempDir { dir =>
+ withSQLConf(
+        DecryptionPropertiesFactory.CRYPTO_FACTORY_CLASS_PROPERTY_NAME -> cryptoFactoryClass,
+ KeyToolkit.KMS_CLIENT_CLASS_PROPERTY_NAME ->
+ "org.apache.parquet.crypto.keytools.mocks.InMemoryKMS",
+ InMemoryKMS.KEY_LIST_PROPERTY_NAME ->
+ s"footerKey: ${footerKey}, key1: ${key1}, key2: ${key2}") {
+
+ // Make sure encryption works with multiple Parquet files
+ val inputDF = spark
+ .range(0, 2000)
+ .map(i => (i, i.toString, i.toFloat))
+ .repartition(10)
+ .toDF("a", "b", "c")
+ val parquetDir = new File(dir, "parquet").getCanonicalPath
+ inputDF.write
+          .option(PropertiesDrivenCryptoFactory.COLUMN_KEYS_PROPERTY_NAME, "key1: a, b; key2: c")
+          .option(PropertiesDrivenCryptoFactory.FOOTER_KEY_PROPERTY_NAME, "footerKey")
+ .parquet(parquetDir)
+
+ verifyParquetEncrypted(parquetDir)
+
+ val parquetDF = spark.read.parquet(parquetDir)
+ assert(parquetDF.inputFiles.nonEmpty)
+ val readDataset = parquetDF.select("a", "b", "c")
+
+ if (CometConf.COMET_ENABLED.get(conf)) {
+ checkSparkAnswerAndOperator(readDataset)
+ } else {
+ checkAnswer(readDataset, inputDF)
}
+ }
}
}
test("SPARK-37117: Can't read files in Parquet encryption external key
material mode") {
-    assume(CometConf.COMET_NATIVE_SCAN_IMPL.get() != CometConf.SCAN_NATIVE_DATAFUSION)
-    assume(CometConf.COMET_NATIVE_SCAN_IMPL.get() != CometConf.SCAN_NATIVE_ICEBERG_COMPAT)
import testImplicits._
-    Seq("org.apache.parquet.crypto.keytools.PropertiesDrivenCryptoFactory").foreach {
- factoryClass =>
- withTempDir { dir =>
- withSQLConf(
- "parquet.crypto.factory.class" -> factoryClass,
- "parquet.encryption.kms.client.class" ->
- "org.apache.parquet.crypto.keytools.mocks.InMemoryKMS",
- "parquet.encryption.key.material.store.internally" -> "false",
- "parquet.encryption.key.list" ->
- s"footerKey: ${footerKey}, key1: ${key1}, key2: ${key2}") {
-
- val inputDF = spark
- .range(0, 2000)
- .map(i => (i, i.toString, i.toFloat))
- .repartition(10)
- .toDF("a", "b", "c")
- val parquetDir = new File(dir, "parquet").getCanonicalPath
- inputDF.write
- .option("parquet.encryption.column.keys", "key1: a, b; key2: c")
- .option("parquet.encryption.footer.key", "footerKey")
- .parquet(parquetDir)
-
- val parquetDF = spark.read.parquet(parquetDir)
- assert(parquetDF.inputFiles.nonEmpty)
- val readDataset = parquetDF.select("a", "b", "c")
-
- if (CometConf.COMET_ENABLED.get(conf)) {
- checkSparkAnswerAndOperator(readDataset)
- } else {
- checkAnswer(readDataset, inputDF)
- }
- }
+ withTempDir { dir =>
+ withSQLConf(
+        DecryptionPropertiesFactory.CRYPTO_FACTORY_CLASS_PROPERTY_NAME -> cryptoFactoryClass,
+ KeyToolkit.KMS_CLIENT_CLASS_PROPERTY_NAME ->
+ "org.apache.parquet.crypto.keytools.mocks.InMemoryKMS",
+        KeyToolkit.KEY_MATERIAL_INTERNAL_PROPERTY_NAME -> "false", // default is true
+ InMemoryKMS.KEY_LIST_PROPERTY_NAME ->
+ s"footerKey: ${footerKey}, key1: ${key1}, key2: ${key2}") {
+
+ val inputDF = spark
+ .range(0, 2000)
+ .map(i => (i, i.toString, i.toFloat))
+ .repartition(10)
+ .toDF("a", "b", "c")
+ val parquetDir = new File(dir, "parquet").getCanonicalPath
+ inputDF.write
+          .option(PropertiesDrivenCryptoFactory.COLUMN_KEYS_PROPERTY_NAME, "key1: a, b; key2: c")
+          .option(PropertiesDrivenCryptoFactory.FOOTER_KEY_PROPERTY_NAME, "footerKey")
+ .parquet(parquetDir)
+
+ verifyParquetEncrypted(parquetDir)
+
+ val parquetDF = spark.read.parquet(parquetDir)
+ assert(parquetDF.inputFiles.nonEmpty)
+ val readDataset = parquetDF.select("a", "b", "c")
+
+ if (CometConf.COMET_ENABLED.get(conf)) {
+ checkSparkAnswerAndOperator(readDataset)
+ } else {
+ checkAnswer(readDataset, inputDF)
+ }
+ }
+ }
+ }
+
+ test("SPARK-42114: Test of uniform parquet encryption") {
+
+ import testImplicits._
+
+ withTempDir { dir =>
+ withSQLConf(
+        DecryptionPropertiesFactory.CRYPTO_FACTORY_CLASS_PROPERTY_NAME -> cryptoFactoryClass,
+ KeyToolkit.KMS_CLIENT_CLASS_PROPERTY_NAME ->
+ "org.apache.parquet.crypto.keytools.mocks.InMemoryKMS",
+ InMemoryKMS.KEY_LIST_PROPERTY_NAME ->
+ s"key1: ${key1}") {
+
+ val inputDF = spark
+ .range(0, 2000)
+ .map(i => (i, i.toString, i.toFloat))
+ .repartition(10)
+ .toDF("a", "b", "c")
+ val parquetDir = new File(dir, "parquet").getCanonicalPath
+ inputDF.write
+ .option("parquet.encryption.uniform.key", "key1")
+ .parquet(parquetDir)
+
+ verifyParquetEncrypted(parquetDir)
+
+ val parquetDF = spark.read.parquet(parquetDir)
+ assert(parquetDF.inputFiles.nonEmpty)
+ val readDataset = parquetDF.select("a", "b", "c")
+
+ if (CometConf.COMET_ENABLED.get(conf)) {
+ checkSparkAnswerAndOperator(readDataset)
+ } else {
+ checkAnswer(readDataset, inputDF)
+ }
+ }
+ }
+ }
+
+ test("Plain text footer mode") {
+ import testImplicits._
+
+ withTempDir { dir =>
+ withSQLConf(
+        DecryptionPropertiesFactory.CRYPTO_FACTORY_CLASS_PROPERTY_NAME -> cryptoFactoryClass,
+ KeyToolkit.KMS_CLIENT_CLASS_PROPERTY_NAME ->
+ "org.apache.parquet.crypto.keytools.mocks.InMemoryKMS",
+        PropertiesDrivenCryptoFactory.PLAINTEXT_FOOTER_PROPERTY_NAME -> "true", // default is false
+ InMemoryKMS.KEY_LIST_PROPERTY_NAME ->
+ s"footerKey: ${footerKey}, key1: ${key1}, key2: ${key2}") {
+
+ val inputDF = spark
+ .range(0, 1000)
+ .map(i => (i, i.toString, i.toFloat))
+ .repartition(5)
+ .toDF("a", "b", "c")
+ val parquetDir = new File(dir, "parquet").getCanonicalPath
+ inputDF.write
+          .option(PropertiesDrivenCryptoFactory.COLUMN_KEYS_PROPERTY_NAME, "key1: a, b; key2: c")
+          .option(PropertiesDrivenCryptoFactory.FOOTER_KEY_PROPERTY_NAME, "footerKey")
+ .parquet(parquetDir)
+
+ verifyParquetPlaintextFooter(parquetDir)
+
+ val parquetDF = spark.read.parquet(parquetDir)
+ assert(parquetDF.inputFiles.nonEmpty)
+ val readDataset = parquetDF.select("a", "b", "c")
+
+ if (CometConf.COMET_ENABLED.get(conf)) {
+ checkSparkAnswerAndOperator(readDataset)
+ } else {
+ checkAnswer(readDataset, inputDF)
+ }
+ }
+ }
+ }
+
+ test("Change encryption algorithm") {
+ import testImplicits._
+
+ withTempDir { dir =>
+ withSQLConf(
+        DecryptionPropertiesFactory.CRYPTO_FACTORY_CLASS_PROPERTY_NAME -> cryptoFactoryClass,
+ KeyToolkit.KMS_CLIENT_CLASS_PROPERTY_NAME ->
+ "org.apache.parquet.crypto.keytools.mocks.InMemoryKMS",
+ // default is AES_GCM_V1
+        PropertiesDrivenCryptoFactory.ENCRYPTION_ALGORITHM_PROPERTY_NAME -> "AES_GCM_CTR_V1",
+ InMemoryKMS.KEY_LIST_PROPERTY_NAME ->
+ s"footerKey: ${footerKey}, key1: ${key1}, key2: ${key2}") {
+
+ val inputDF = spark
+ .range(0, 1000)
+ .map(i => (i, i.toString, i.toFloat))
+ .repartition(5)
+ .toDF("a", "b", "c")
+ val parquetDir = new File(dir, "parquet").getCanonicalPath
+ inputDF.write
+          .option(PropertiesDrivenCryptoFactory.COLUMN_KEYS_PROPERTY_NAME, "key1: a, b; key2: c")
+          .option(PropertiesDrivenCryptoFactory.FOOTER_KEY_PROPERTY_NAME, "footerKey")
+ .parquet(parquetDir)
+
+ verifyParquetEncrypted(parquetDir)
+
+ val parquetDF = spark.read.parquet(parquetDir)
+ assert(parquetDF.inputFiles.nonEmpty)
+ val readDataset = parquetDF.select("a", "b", "c")
+
+        // native_datafusion and native_iceberg_compat fall back due to Arrow-rs
+        // https://github.com/apache/arrow-rs/blob/da9829728e2a9dffb8d4f47ffe7b103793851724/parquet/src/file/metadata/parser.rs#L494
+        if (CometConf.COMET_ENABLED.get(conf) && CometConf.COMET_NATIVE_SCAN_IMPL.get(
+            conf) == SCAN_NATIVE_COMET) {
+ checkSparkAnswerAndOperator(readDataset)
+ } else {
+ checkAnswer(readDataset, inputDF)
+ }
+ }
+ }
+ }
+
+ test("Test double wrapping disabled") {
+ import testImplicits._
+
+ withTempDir { dir =>
+ withSQLConf(
+        DecryptionPropertiesFactory.CRYPTO_FACTORY_CLASS_PROPERTY_NAME -> cryptoFactoryClass,
+ KeyToolkit.KMS_CLIENT_CLASS_PROPERTY_NAME ->
+ "org.apache.parquet.crypto.keytools.mocks.InMemoryKMS",
+ KeyToolkit.DOUBLE_WRAPPING_PROPERTY_NAME -> "false", // default is true
+ InMemoryKMS.KEY_LIST_PROPERTY_NAME ->
+ s"footerKey: ${footerKey}, key1: ${key1}, key2: ${key2}") {
+
+ val inputDF = spark
+ .range(0, 1000)
+ .map(i => (i, i.toString, i.toFloat))
+ .repartition(5)
+ .toDF("a", "b", "c")
+ val parquetDir = new File(dir, "parquet").getCanonicalPath
+ inputDF.write
+          .option(PropertiesDrivenCryptoFactory.COLUMN_KEYS_PROPERTY_NAME, "key1: a, b; key2: c")
+          .option(PropertiesDrivenCryptoFactory.FOOTER_KEY_PROPERTY_NAME, "footerKey")
+ .parquet(parquetDir)
+
+ verifyParquetEncrypted(parquetDir)
+
+ val parquetDF = spark.read.parquet(parquetDir)
+ assert(parquetDF.inputFiles.nonEmpty)
+ val readDataset = parquetDF.select("a", "b", "c")
+
+ if (CometConf.COMET_ENABLED.get(conf)) {
+ checkSparkAnswerAndOperator(readDataset)
+ } else {
+ checkAnswer(readDataset, inputDF)
+ }
+ }
+ }
+ }
+
+ test("Join between files with different encryption keys") {
+ import testImplicits._
+
+ withTempDir { dir =>
+ withSQLConf(
+        DecryptionPropertiesFactory.CRYPTO_FACTORY_CLASS_PROPERTY_NAME -> cryptoFactoryClass,
+ KeyToolkit.KMS_CLIENT_CLASS_PROPERTY_NAME ->
+ "org.apache.parquet.crypto.keytools.mocks.InMemoryKMS",
+ InMemoryKMS.KEY_LIST_PROPERTY_NAME ->
+ s"footerKey: ${footerKey}, key1: ${key1}, key2: ${key2}") {
+
+ // Write first file
+ val inputDF1 = spark
+ .range(0, 100)
+ .map(i => (i, s"file1_${i}", i.toFloat))
+ .toDF("id", "name", "value")
+ val parquetDir1 = new File(dir, "parquet1").getCanonicalPath
+ inputDF1.write
+ .option(
+ PropertiesDrivenCryptoFactory.COLUMN_KEYS_PROPERTY_NAME,
+ "key1: id, name, value")
+          .option(PropertiesDrivenCryptoFactory.FOOTER_KEY_PROPERTY_NAME, "footerKey")
+ .parquet(parquetDir1)
+
+ // Write second file using different column key
+ val inputDF2 = spark
+ .range(0, 100)
+ .map(i => (i, s"file2_${i}", (i * 2).toFloat))
+ .toDF("id", "description", "score")
+ val parquetDir2 = new File(dir, "parquet2").getCanonicalPath
+ inputDF2.write
+ .option(
+ PropertiesDrivenCryptoFactory.COLUMN_KEYS_PROPERTY_NAME,
+ "key2: id, description, score")
+          .option(PropertiesDrivenCryptoFactory.FOOTER_KEY_PROPERTY_NAME, "footerKey")
+ .parquet(parquetDir2)
+
+ verifyParquetEncrypted(parquetDir1)
+ verifyParquetEncrypted(parquetDir2)
+
+        // Now perform a join between the two files with different encryption keys.
+        // This tests that hadoopConf properties propagate correctly to each scan.
+ val parquetDF1 = spark.read.parquet(parquetDir1).alias("f1")
+ val parquetDF2 = spark.read.parquet(parquetDir2).alias("f2")
+
+ val joinedDF = parquetDF1
+ .join(parquetDF2, parquetDF1("id") === parquetDF2("id"), "inner")
+ .select(
+ parquetDF1("id"),
+ parquetDF1("name"),
+ parquetDF2("description"),
+ parquetDF2("score"))
+
+ if (CometConf.COMET_ENABLED.get(conf)) {
+ checkSparkAnswerAndOperator(joinedDF)
+ } else {
+ checkSparkAnswer(joinedDF)
+ }
+ }
+ }
+ }
+
+  // Union ends up with two scans in the same plan, so this ensures that Comet can distinguish
+  // between the hadoopConfs for each relation
+ test("Union between files with different encryption keys") {
+ import testImplicits._
+
+ withTempDir { dir =>
+ withSQLConf(
+        DecryptionPropertiesFactory.CRYPTO_FACTORY_CLASS_PROPERTY_NAME -> cryptoFactoryClass,
+ KeyToolkit.KMS_CLIENT_CLASS_PROPERTY_NAME ->
+ "org.apache.parquet.crypto.keytools.mocks.InMemoryKMS",
+ InMemoryKMS.KEY_LIST_PROPERTY_NAME ->
+ s"footerKey: ${footerKey}, key1: ${key1}, key2: ${key2}") {
+
+ // Write first file with key1
+ val inputDF1 = spark
+ .range(0, 100)
+ .map(i => (i, s"file1_${i}", i.toFloat))
+ .toDF("id", "name", "value")
+ val parquetDir1 = new File(dir, "parquet1").getCanonicalPath
+ inputDF1.write
+ .option(
+ PropertiesDrivenCryptoFactory.COLUMN_KEYS_PROPERTY_NAME,
+ "key1: id, name, value")
+          .option(PropertiesDrivenCryptoFactory.FOOTER_KEY_PROPERTY_NAME, "footerKey")
+ .parquet(parquetDir1)
+
+ // Write second file with key2 - same schema, different encryption key
+ val inputDF2 = spark
+ .range(100, 200)
+ .map(i => (i, s"file2_${i}", i.toFloat))
+ .toDF("id", "name", "value")
+ val parquetDir2 = new File(dir, "parquet2").getCanonicalPath
+ inputDF2.write
+ .option(
+ PropertiesDrivenCryptoFactory.COLUMN_KEYS_PROPERTY_NAME,
+ "key2: id, name, value")
+          .option(PropertiesDrivenCryptoFactory.FOOTER_KEY_PROPERTY_NAME, "footerKey")
+ .parquet(parquetDir2)
+
+ verifyParquetEncrypted(parquetDir1)
+ verifyParquetEncrypted(parquetDir2)
+
+ val parquetDF1 = spark.read.parquet(parquetDir1)
+ val parquetDF2 = spark.read.parquet(parquetDir2)
+
+ val unionDF = parquetDF1.union(parquetDF2)
+
+ if (CometConf.COMET_ENABLED.get(conf)) {
+ checkSparkAnswerAndOperator(unionDF)
+ } else {
+ checkSparkAnswer(unionDF)
}
+ }
+ }
+ }
+
+ test("Test different key lengths") {
+ import testImplicits._
+
+ withTempDir { dir =>
+ withSQLConf(
+        DecryptionPropertiesFactory.CRYPTO_FACTORY_CLASS_PROPERTY_NAME -> cryptoFactoryClass,
+ KeyToolkit.KMS_CLIENT_CLASS_PROPERTY_NAME ->
+ "org.apache.parquet.crypto.keytools.mocks.InMemoryKMS",
+ KeyToolkit.DATA_KEY_LENGTH_PROPERTY_NAME -> "256",
+ KeyToolkit.KEK_LENGTH_PROPERTY_NAME -> "256",
+ InMemoryKMS.KEY_LIST_PROPERTY_NAME ->
+ s"footerKey: ${footerKey}, key1: ${key1}, key2: ${key2}") {
+
+ val inputDF = spark
+ .range(0, 1000)
+ .map(i => (i, i.toString, i.toFloat))
+ .repartition(5)
+ .toDF("a", "b", "c")
+ val parquetDir = new File(dir, "parquet").getCanonicalPath
+ inputDF.write
+          .option(PropertiesDrivenCryptoFactory.COLUMN_KEYS_PROPERTY_NAME, "key1: a, b; key2: c")
+          .option(PropertiesDrivenCryptoFactory.FOOTER_KEY_PROPERTY_NAME, "footerKey")
+ .parquet(parquetDir)
+
+ verifyParquetEncrypted(parquetDir)
+
+ val parquetDF = spark.read.parquet(parquetDir)
+ assert(parquetDF.inputFiles.nonEmpty)
+ val readDataset = parquetDF.select("a", "b", "c")
+
+        // native_datafusion and native_iceberg_compat fall back due to Arrow-rs not
+        // supporting other key lengths
+        if (CometConf.COMET_ENABLED.get(conf) && CometConf.COMET_NATIVE_SCAN_IMPL.get(
+            conf) == SCAN_NATIVE_COMET) {
+ checkSparkAnswerAndOperator(readDataset)
+ } else {
+ checkAnswer(readDataset, inputDF)
+ }
+ }
}
}
@@ -146,13 +458,29 @@ class ParquetEncryptionITCase extends CometTestBase with SQLTestUtils {
  override protected def test(testName: String, testTags: Tag*)(testFun: => Any)(implicit
pos: Position): Unit = {
+
Seq("true", "false").foreach { cometEnabled =>
- super.test(testName + s" Comet($cometEnabled)", testTags: _*) {
- withSQLConf(
- CometConf.COMET_ENABLED.key -> cometEnabled,
- CometConf.COMET_EXEC_ENABLED.key -> "true",
- SQLConf.ANSI_ENABLED.key -> "true") {
- testFun
+ if (cometEnabled == "true") {
+        Seq(SCAN_NATIVE_COMET, SCAN_NATIVE_DATAFUSION, SCAN_NATIVE_ICEBERG_COMPAT).foreach {
+          scanImpl =>
+            super.test(testName + s" Comet($cometEnabled)" + s" Scan($scanImpl)", testTags: _*) {
+ withSQLConf(
+ CometConf.COMET_ENABLED.key -> cometEnabled,
+ CometConf.COMET_EXEC_ENABLED.key -> "true",
+ SQLConf.ANSI_ENABLED.key -> "false",
+ CometConf.COMET_NATIVE_SCAN_IMPL.key -> scanImpl) {
+ testFun
+ }
+ }
+ }
+ } else {
+ super.test(testName + s" Comet($cometEnabled)", testTags: _*) {
+ withSQLConf(
+ CometConf.COMET_ENABLED.key -> cometEnabled,
+ CometConf.COMET_EXEC_ENABLED.key -> "true",
+ SQLConf.ANSI_ENABLED.key -> "false") {
+ testFun
+ }
}
}
}
@@ -164,7 +492,9 @@ class ParquetEncryptionITCase extends CometTestBase with SQLTestUtils {
}
private var _spark: SparkSessionType = _
+
protected implicit override def spark: SparkSessionType = _spark
+
protected implicit override def sqlContext: SQLContext = _spark.sqlContext
/**
@@ -182,12 +512,50 @@ class ParquetEncryptionITCase extends CometTestBase with SQLTestUtils {
val byteArray = new Array[Byte](magicStringLength)
val randomAccessFile = new RandomAccessFile(parquetFile, "r")
try {
+ // Check first 4 bytes
+ randomAccessFile.read(byteArray, 0, magicStringLength)
+ val firstMagicString = new String(byteArray, StandardCharsets.UTF_8)
+ assert(magicString == firstMagicString)
+
+ // Check last 4 bytes
+ randomAccessFile.seek(randomAccessFile.length() - magicStringLength)
+ randomAccessFile.read(byteArray, 0, magicStringLength)
+ val lastMagicString = new String(byteArray, StandardCharsets.UTF_8)
+ assert(magicString == lastMagicString)
+ } finally {
+ randomAccessFile.close()
+ }
+ }
+ }
+
+ /**
+   * Verify that the directory contains an encrypted parquet in plaintext footer mode by
+   * checking that the magic string of every parquet part file in the directory is PAR1,
+   * as defined in the spec:
+   * https://github.com/apache/parquet-format/blob/master/Encryption.md#55-plaintext-footer-mode
+ */
+ private def verifyParquetPlaintextFooter(parquetDir: String): Unit = {
+ val parquetPartitionFiles = getListOfParquetFiles(new File(parquetDir))
+ assert(parquetPartitionFiles.size >= 1)
+ parquetPartitionFiles.foreach { parquetFile =>
+ val magicString = "PAR1"
+ val magicStringLength = magicString.length()
+ val byteArray = new Array[Byte](magicStringLength)
+ val randomAccessFile = new RandomAccessFile(parquetFile, "r")
+ try {
+ // Check first 4 bytes
+ randomAccessFile.read(byteArray, 0, magicStringLength)
+ val firstMagicString = new String(byteArray, StandardCharsets.UTF_8)
+ assert(magicString == firstMagicString)
+
+ // Check last 4 bytes
+ randomAccessFile.seek(randomAccessFile.length() - magicStringLength)
randomAccessFile.read(byteArray, 0, magicStringLength)
+ val lastMagicString = new String(byteArray, StandardCharsets.UTF_8)
+ assert(magicString == lastMagicString)
} finally {
randomAccessFile.close()
}
- val stringRead = new String(byteArray, StandardCharsets.UTF_8)
- assert(magicString == stringRead)
}
}
---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]