This is an automated email from the ASF dual-hosted git repository.

xuanwo pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/iceberg-rust.git


The following commit(s) were added to refs/heads/main by this push:
     new 1d217d49 Add support for custom credential loader for S3 FileIO (#1528)
1d217d49 is described below

commit 1d217d4943bcce0ec97bb135fba10cda14da6b79
Author: Phillip LeBlanc <[email protected]>
AuthorDate: Mon Jul 21 12:38:02 2025 +0900

    Add support for custom credential loader for S3 FileIO (#1528)
    
    ## Which issue does this PR close?
    
    - Closes #1527
    
    ## What changes are included in this PR?
    
    Adds the ability to provide custom extensions to the `FileIOBuilder`.
    Currently the only supported extension is `CustomAwsCredentialLoader`
    which is a newtype around
    
[`AwsCredentialLoad`](https://docs.rs/reqsign/0.16.3/reqsign/trait.AwsCredentialLoad.html),
    which is what OpenDAL expects.
    
    I've added extensions to the `RestCatalog` as well, and when it's
    constructing `FileIO` for table operations, passes along any defined
    extensions into the `FileIOBuilder`, which then get passed into the
    underlying OpenDAL constructors.
    
    ## Are these changes tested?
    
    Yes, tests added in `crates/iceberg/tests/file_io_s3_test.rs` that
    verify the extension is working.
---
 Cargo.lock                              |   1 +
 crates/catalog/rest/src/catalog.rs      |  17 +++-
 crates/iceberg/Cargo.toml               |   3 +-
 crates/iceberg/src/io/file_io.rs        |  56 ++++++++++-
 crates/iceberg/src/io/storage.rs        |  11 ++-
 crates/iceberg/src/io/storage_s3.rs     |  50 +++++++++-
 crates/iceberg/tests/file_io_s3_test.rs | 168 ++++++++++++++++++++++++++++++--
 7 files changed, 291 insertions(+), 15 deletions(-)

diff --git a/Cargo.lock b/Cargo.lock
index 0b733163..072928ab 100644
--- a/Cargo.lock
+++ b/Cargo.lock
@@ -3535,6 +3535,7 @@ dependencies = [
  "pretty_assertions",
  "rand 0.8.5",
  "regex",
+ "reqsign",
  "reqwest",
  "roaring",
  "rust_decimal",
diff --git a/crates/catalog/rest/src/catalog.rs 
b/crates/catalog/rest/src/catalog.rs
index 1bca0374..5c9e6e15 100644
--- a/crates/catalog/rest/src/catalog.rs
+++ b/crates/catalog/rest/src/catalog.rs
@@ -17,11 +17,12 @@
 
 //! This module contains the iceberg REST catalog implementation.
 
+use std::any::Any;
 use std::collections::HashMap;
 use std::str::FromStr;
 
 use async_trait::async_trait;
-use iceberg::io::FileIO;
+use iceberg::io::{self, FileIO};
 use iceberg::table::Table;
 use iceberg::{
     Catalog, Error, ErrorKind, Namespace, NamespaceIdent, Result, TableCommit, 
TableCreation,
@@ -240,6 +241,8 @@ pub struct RestCatalog {
     /// It's could be different from the config fetched from the server and 
used at runtime.
     user_config: RestCatalogConfig,
     ctx: OnceCell<RestContext>,
+    /// Extensions for the FileIOBuilder.
+    file_io_extensions: io::Extensions,
 }
 
 impl RestCatalog {
@@ -248,9 +251,16 @@ impl RestCatalog {
         Self {
             user_config: config,
             ctx: OnceCell::new(),
+            file_io_extensions: io::Extensions::default(),
         }
     }
 
+    /// Add an extension to the file IO builder.
+    pub fn with_file_io_extension<T: Any + Send + Sync>(mut self, ext: T) -> 
Self {
+        self.file_io_extensions.add(ext);
+        self
+    }
+
     /// Gets the [`RestContext`] from the catalog.
     async fn context(&self) -> Result<&RestContext> {
         self.ctx
@@ -307,7 +317,10 @@ impl RestCatalog {
         };
 
         let file_io = match warehouse_path.or(metadata_location) {
-            Some(url) => FileIO::from_path(url)?.with_props(props).build()?,
+            Some(url) => FileIO::from_path(url)?
+                .with_props(props)
+                .with_extensions(self.file_io_extensions.clone())
+                .build()?,
             None => {
                 return Err(Error::new(
                     ErrorKind::Unexpected,
diff --git a/crates/iceberg/Cargo.toml b/crates/iceberg/Cargo.toml
index d6e00e4e..fe4fdf73 100644
--- a/crates/iceberg/Cargo.toml
+++ b/crates/iceberg/Cargo.toml
@@ -37,7 +37,7 @@ storage-fs = ["opendal/services-fs"]
 storage-gcs = ["opendal/services-gcs"]
 storage-memory = ["opendal/services-memory"]
 storage-oss = ["opendal/services-oss"]
-storage-s3 = ["opendal/services-s3"]
+storage-s3 = ["opendal/services-s3", "reqsign"]
 
 async-std = ["dep:async-std"]
 tokio = ["tokio/rt-multi-thread"]
@@ -76,6 +76,7 @@ ordered-float = { workspace = true }
 parquet = { workspace = true, features = ["async"] }
 rand = { workspace = true }
 reqwest = { workspace = true }
+reqsign = { version = "0.16.3", optional = true, default-features = false }
 roaring = { workspace = true }
 rust_decimal = { workspace = true }
 serde = { workspace = true }
diff --git a/crates/iceberg/src/io/file_io.rs b/crates/iceberg/src/io/file_io.rs
index 389397ec..087e98ce 100644
--- a/crates/iceberg/src/io/file_io.rs
+++ b/crates/iceberg/src/io/file_io.rs
@@ -15,6 +15,7 @@
 // specific language governing permissions and limitations
 // under the License.
 
+use std::any::{Any, TypeId};
 use std::collections::HashMap;
 use std::ops::Range;
 use std::sync::Arc;
@@ -167,6 +168,31 @@ impl FileIO {
     }
 }
 
+/// Container for storing type-safe extensions used to configure underlying 
FileIO behavior.
+#[derive(Clone, Debug, Default)]
+pub struct Extensions(HashMap<TypeId, Arc<dyn Any + Send + Sync>>);
+
+impl Extensions {
+    /// Add an extension.
+    pub fn add<T: Any + Send + Sync>(&mut self, ext: T) {
+        self.0.insert(TypeId::of::<T>(), Arc::new(ext));
+    }
+
+    /// Extends the current set of extensions with another set of extensions.
+    pub fn extend(&mut self, extensions: Extensions) {
+        self.0.extend(extensions.0);
+    }
+
+    /// Fetch an extension.
+    pub fn get<T>(&self) -> Option<Arc<T>>
+    where T: 'static + Send + Sync + Clone {
+        let type_id = TypeId::of::<T>();
+        self.0
+            .get(&type_id)
+            .and_then(|arc_any| Arc::clone(arc_any).downcast::<T>().ok())
+    }
+}
+
 /// Builder for [`FileIO`].
 #[derive(Clone, Debug)]
 pub struct FileIOBuilder {
@@ -176,6 +202,8 @@ pub struct FileIOBuilder {
     scheme_str: Option<String>,
     /// Arguments for operator.
     props: HashMap<String, String>,
+    /// Optional extensions to configure the underlying FileIO behavior.
+    extensions: Extensions,
 }
 
 impl FileIOBuilder {
@@ -185,6 +213,7 @@ impl FileIOBuilder {
         Self {
             scheme_str: Some(scheme_str.to_string()),
             props: HashMap::default(),
+            extensions: Extensions::default(),
         }
     }
 
@@ -193,14 +222,19 @@ impl FileIOBuilder {
         Self {
             scheme_str: None,
             props: HashMap::default(),
+            extensions: Extensions::default(),
         }
     }
 
     /// Fetch the scheme string.
     ///
     /// The scheme_str will be empty if it's None.
-    pub fn into_parts(self) -> (String, HashMap<String, String>) {
-        (self.scheme_str.unwrap_or_default(), self.props)
+    pub fn into_parts(self) -> (String, HashMap<String, String>, Extensions) {
+        (
+            self.scheme_str.unwrap_or_default(),
+            self.props,
+            self.extensions,
+        )
     }
 
     /// Add argument for operator.
@@ -219,6 +253,24 @@ impl FileIOBuilder {
         self
     }
 
+    /// Add an extension to the file IO builder.
+    pub fn with_extension<T: Any + Send + Sync>(mut self, ext: T) -> Self {
+        self.extensions.add(ext);
+        self
+    }
+
+    /// Adds multiple extensions to the file IO builder.
+    pub fn with_extensions(mut self, extensions: Extensions) -> Self {
+        self.extensions.extend(extensions);
+        self
+    }
+
+    /// Fetch an extension from the file IO builder.
+    pub fn extension<T>(&self) -> Option<Arc<T>>
+    where T: 'static + Send + Sync + Clone {
+        self.extensions.get::<T>()
+    }
+
     /// Builds [`FileIO`].
     pub fn build(self) -> Result<FileIO> {
         let storage = Storage::build(self.clone())?;
diff --git a/crates/iceberg/src/io/storage.rs b/crates/iceberg/src/io/storage.rs
index a847977e..3de4f10d 100644
--- a/crates/iceberg/src/io/storage.rs
+++ b/crates/iceberg/src/io/storage.rs
@@ -31,6 +31,8 @@ use opendal::{Operator, Scheme};
 #[cfg(feature = "storage-azdls")]
 use super::AzureStorageScheme;
 use super::FileIOBuilder;
+#[cfg(feature = "storage-s3")]
+use crate::io::CustomAwsCredentialLoader;
 use crate::{Error, ErrorKind};
 
 /// The storage carries all supported storage services in iceberg
@@ -47,6 +49,7 @@ pub(crate) enum Storage {
         /// Storing the scheme string here to return the correct path.
         configured_scheme: String,
         config: Arc<S3Config>,
+        customized_credential_load: Option<CustomAwsCredentialLoader>,
     },
     #[cfg(feature = "storage-gcs")]
     Gcs { config: Arc<GcsConfig> },
@@ -67,7 +70,7 @@ pub(crate) enum Storage {
 impl Storage {
     /// Convert iceberg config to opendal config.
     pub(crate) fn build(file_io_builder: FileIOBuilder) -> crate::Result<Self> 
{
-        let (scheme_str, props) = file_io_builder.into_parts();
+        let (scheme_str, props, extensions) = file_io_builder.into_parts();
         let scheme = Self::parse_scheme(&scheme_str)?;
 
         match scheme {
@@ -79,6 +82,9 @@ impl Storage {
             Scheme::S3 => Ok(Self::S3 {
                 configured_scheme: scheme_str,
                 config: super::s3_config_parse(props)?.into(),
+                customized_credential_load: extensions
+                    .get::<CustomAwsCredentialLoader>()
+                    .map(Arc::unwrap_or_clone),
             }),
             #[cfg(feature = "storage-gcs")]
             Scheme::Gcs => Ok(Self::Gcs {
@@ -144,8 +150,9 @@ impl Storage {
             Storage::S3 {
                 configured_scheme,
                 config,
+                customized_credential_load,
             } => {
-                let op = super::s3_config_build(config, path)?;
+                let op = super::s3_config_build(config, 
customized_credential_load, path)?;
                 let op_info = op.info();
 
                 // Check prefix of s3 path.
diff --git a/crates/iceberg/src/io/storage_s3.rs 
b/crates/iceberg/src/io/storage_s3.rs
index 8396888c..f2408331 100644
--- a/crates/iceberg/src/io/storage_s3.rs
+++ b/crates/iceberg/src/io/storage_s3.rs
@@ -16,9 +16,13 @@
 // under the License.
 
 use std::collections::HashMap;
+use std::sync::Arc;
 
+use async_trait::async_trait;
 use opendal::services::S3Config;
 use opendal::{Configurator, Operator};
+pub use reqsign::{AwsCredential, AwsCredentialLoad};
+use reqwest::Client;
 use url::Url;
 
 use crate::io::is_truthy;
@@ -151,7 +155,11 @@ pub(crate) fn s3_config_parse(mut m: HashMap<String, 
String>) -> Result<S3Config
 }
 
 /// Build new opendal operator from give path.
-pub(crate) fn s3_config_build(cfg: &S3Config, path: &str) -> Result<Operator> {
+pub(crate) fn s3_config_build(
+    cfg: &S3Config,
+    customized_credential_load: &Option<CustomAwsCredentialLoader>,
+    path: &str,
+) -> Result<Operator> {
     let url = Url::parse(path)?;
     let bucket = url.host_str().ok_or_else(|| {
         Error::new(
@@ -160,11 +168,49 @@ pub(crate) fn s3_config_build(cfg: &S3Config, path: &str) 
-> Result<Operator> {
         )
     })?;
 
-    let builder = cfg
+    let mut builder = cfg
         .clone()
         .into_builder()
         // Set bucket name.
         .bucket(bucket);
 
+    if let Some(customized_credential_load) = customized_credential_load {
+        builder = builder
+            
.customized_credential_load(customized_credential_load.clone().into_opendal_loader());
+    }
+
     Ok(Operator::new(builder)?.finish())
 }
+
+/// Custom AWS credential loader.
+/// This can be used to load credentials from a custom source, such as the AWS 
SDK.
+///
+/// This should be set as an extension on `FileIOBuilder`.
+#[derive(Clone)]
+pub struct CustomAwsCredentialLoader(Arc<dyn AwsCredentialLoad>);
+
+impl std::fmt::Debug for CustomAwsCredentialLoader {
+    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
+        f.debug_struct("CustomAwsCredentialLoader")
+            .finish_non_exhaustive()
+    }
+}
+
+impl CustomAwsCredentialLoader {
+    /// Create a new custom AWS credential loader.
+    pub fn new(loader: Arc<dyn AwsCredentialLoad>) -> Self {
+        Self(loader)
+    }
+
+    /// Convert this loader into an opendal compatible loader for customized 
AWS credentials.
+    pub fn into_opendal_loader(self) -> Box<dyn AwsCredentialLoad> {
+        Box::new(self)
+    }
+}
+
+#[async_trait]
+impl AwsCredentialLoad for CustomAwsCredentialLoader {
+    async fn load_credential(&self, client: Client) -> 
anyhow::Result<Option<AwsCredential>> {
+        self.0.load_credential(client).await
+    }
+}
diff --git a/crates/iceberg/tests/file_io_s3_test.rs 
b/crates/iceberg/tests/file_io_s3_test.rs
index eab6853b..b7c484de 100644
--- a/crates/iceberg/tests/file_io_s3_test.rs
+++ b/crates/iceberg/tests/file_io_s3_test.rs
@@ -18,15 +18,19 @@
 //! Integration tests for FileIO S3.
 #[cfg(all(test, feature = "storage-s3"))]
 mod tests {
-    use std::net::SocketAddr;
-    use std::sync::RwLock;
+    use std::net::{IpAddr, SocketAddr};
+    use std::sync::{Arc, RwLock};
 
+    use async_trait::async_trait;
     use ctor::{ctor, dtor};
     use iceberg::io::{
-        FileIO, FileIOBuilder, S3_ACCESS_KEY_ID, S3_ENDPOINT, S3_REGION, 
S3_SECRET_ACCESS_KEY,
+        CustomAwsCredentialLoader, FileIO, FileIOBuilder, S3_ACCESS_KEY_ID, 
S3_ENDPOINT, S3_REGION,
+        S3_SECRET_ACCESS_KEY,
     };
     use iceberg_test_utils::docker::DockerCompose;
     use iceberg_test_utils::{normalize_test_name, set_up};
+    use reqsign::{AwsCredential, AwsCredentialLoad};
+    use reqwest::Client;
 
     const MINIO_PORT: u16 = 9000;
     static DOCKER_COMPOSE_ENV: RwLock<Option<DockerCompose>> = 
RwLock::new(None);
@@ -51,9 +55,7 @@ mod tests {
     async fn get_file_io() -> FileIO {
         set_up();
 
-        let guard = DOCKER_COMPOSE_ENV.read().unwrap();
-        let docker_compose = guard.as_ref().unwrap();
-        let container_ip = docker_compose.get_container_ip("minio");
+        let container_ip = get_container_ip("minio");
         let minio_socket_addr = SocketAddr::new(container_ip, MINIO_PORT);
 
         FileIOBuilder::new("s3")
@@ -67,6 +69,12 @@ mod tests {
             .unwrap()
     }
 
+    fn get_container_ip(service_name: &str) -> IpAddr {
+        let guard = DOCKER_COMPOSE_ENV.read().unwrap();
+        let docker_compose = guard.as_ref().unwrap();
+        docker_compose.get_container_ip(service_name)
+    }
+
     #[tokio::test]
     async fn test_file_io_s3_exists() {
         let file_io = get_file_io().await;
@@ -100,4 +108,152 @@ mod tests {
             assert_eq!(buffer, "test_input".as_bytes());
         }
     }
+
+    // Mock credential loader for testing
+    struct MockCredentialLoader {
+        credential: Option<AwsCredential>,
+    }
+
+    impl MockCredentialLoader {
+        fn new(credential: Option<AwsCredential>) -> Self {
+            Self { credential }
+        }
+
+        fn new_minio() -> Self {
+            Self::new(Some(AwsCredential {
+                access_key_id: "admin".to_string(),
+                secret_access_key: "password".to_string(),
+                session_token: None,
+                expires_in: None,
+            }))
+        }
+    }
+
+    #[async_trait]
+    impl AwsCredentialLoad for MockCredentialLoader {
+        async fn load_credential(&self, _client: Client) -> 
anyhow::Result<Option<AwsCredential>> {
+            Ok(self.credential.clone())
+        }
+    }
+
+    #[test]
+    fn test_file_io_builder_extension_system() {
+        // Test adding and retrieving extensions
+        let test_string = "test_extension_value".to_string();
+        let builder = 
FileIOBuilder::new_fs_io().with_extension(test_string.clone());
+
+        // Test retrieving the extension
+        let extension: Option<Arc<String>> = builder.extension();
+        assert!(extension.is_some());
+        assert_eq!(*extension.unwrap(), test_string);
+
+        // Test that non-existent extension returns None
+        let non_existent: Option<Arc<i32>> = builder.extension();
+        assert!(non_existent.is_none());
+    }
+
+    #[test]
+    fn test_file_io_builder_multiple_extensions() {
+        // Test adding multiple different types of extensions
+        let test_string = "test_value".to_string();
+        let test_number = 42i32;
+
+        let builder = FileIOBuilder::new_fs_io()
+            .with_extension(test_string.clone())
+            .with_extension(test_number);
+
+        // Retrieve both extensions
+        let string_ext: Option<Arc<String>> = builder.extension();
+        let number_ext: Option<Arc<i32>> = builder.extension();
+
+        assert!(string_ext.is_some());
+        assert!(number_ext.is_some());
+        assert_eq!(*string_ext.unwrap(), test_string);
+        assert_eq!(*number_ext.unwrap(), test_number);
+    }
+
+    #[test]
+    fn test_custom_aws_credential_loader_instantiation() {
+        // Test creating CustomAwsCredentialLoader with mock loader
+        let mock_loader = MockCredentialLoader::new_minio();
+        let custom_loader = 
CustomAwsCredentialLoader::new(Arc::new(mock_loader));
+
+        // Test that the loader can be used in FileIOBuilder
+        let builder = FileIOBuilder::new("s3")
+            .with_extension(custom_loader.clone())
+            .with_props(vec![
+                (S3_ENDPOINT, "http://localhost:9000".to_string()),
+                ("bucket", "test-bucket".to_string()),
+                (S3_REGION, "us-east-1".to_string()),
+            ]);
+
+        // Verify the extension was stored
+        let retrieved_loader: Option<Arc<CustomAwsCredentialLoader>> = 
builder.extension();
+        assert!(retrieved_loader.is_some());
+    }
+
+    #[tokio::test]
+    async fn test_s3_with_custom_credential_loader_integration() {
+        let _file_io = get_file_io().await;
+
+        // Create a mock credential loader
+        let mock_loader = MockCredentialLoader::new_minio();
+        let custom_loader = 
CustomAwsCredentialLoader::new(Arc::new(mock_loader));
+
+        // Get container info for endpoint
+        let container_ip = get_container_ip("minio");
+        let minio_socket_addr = SocketAddr::new(container_ip, MINIO_PORT);
+
+        // Build FileIO with custom credential loader
+        let file_io_with_custom_creds = FileIOBuilder::new("s3")
+            .with_extension(custom_loader)
+            .with_props(vec![
+                (S3_ENDPOINT, format!("http://{}", minio_socket_addr)),
+                (S3_REGION, "us-east-1".to_string()),
+            ])
+            .build()
+            .unwrap();
+
+        // Test that the FileIO was built successfully with the custom loader
+        match file_io_with_custom_creds.exists("s3://bucket1/any").await {
+            Ok(_) => {}
+            Err(e) => panic!("Failed to check existence of bucket: {e}"),
+        }
+    }
+
+    #[tokio::test]
+    async fn test_s3_with_custom_credential_loader_integration_failure() {
+        let _file_io = get_file_io().await;
+
+        // Create a mock credential loader with no credentials
+        let mock_loader = MockCredentialLoader::new(None);
+        let custom_loader = 
CustomAwsCredentialLoader::new(Arc::new(mock_loader));
+
+        // Get container info for endpoint
+        let container_ip = get_container_ip("minio");
+        let minio_socket_addr = SocketAddr::new(container_ip, MINIO_PORT);
+
+        // Build FileIO with custom credential loader
+        let file_io_with_custom_creds = FileIOBuilder::new("s3")
+            .with_extension(custom_loader)
+            .with_props(vec![
+                (S3_ENDPOINT, format!("http://{}", minio_socket_addr)),
+                (S3_REGION, "us-east-1".to_string()),
+            ])
+            .build()
+            .unwrap();
+
+        // Test that the FileIO was built successfully with the custom loader
+        match file_io_with_custom_creds.exists("s3://bucket1/any").await {
+            Ok(_) => panic!(
+                "Expected error, but got Ok - the credential loader should 
fail to provide valid credentials"
+            ),
+            Err(e) => {
+                assert!(
+                    e.to_string()
+                        .contains("no valid credential found and anonymous 
access is not allowed")
+                );
+            }
+        }
+    }
 }

Reply via email to