zeroshade commented on code in PR #3099:
URL: https://github.com/apache/arrow-adbc/pull/3099#discussion_r2190476059


##########
rust/core/src/driver_manager.rs:
##########
@@ -1310,3 +1461,489 @@ impl Drop for ManagedStatement {
         unsafe { method(statement.deref_mut(), null_mut()) };
     }
 }
+
/// Returns the platform key used in a driver manifest's `[Driver.shared]`
/// table, formatted as `<os>_<arch><extra>`.
///
/// * `x86_64` is reported as `amd64` and `aarch64` as `arm64`; any other
///   architecture uses [`std::env::consts::ARCH`].
/// * macOS is reported as `osx`; any other OS uses [`std::env::consts::OS`].
/// * `<extra>` is `_musl` on musl targets, `_mingw` on Windows GNU targets and
///   empty otherwise.
///
/// The value is computed entirely at compile time, so this stays a `const fn`.
const fn current_arch() -> &'static str {
    #[cfg(target_arch = "x86_64")]
    const ARCH: &str = "amd64";
    #[cfg(target_arch = "aarch64")]
    const ARCH: &str = "arm64";
    #[cfg(not(any(target_arch = "x86_64", target_arch = "aarch64")))]
    const ARCH: &str = std::env::consts::ARCH;

    #[cfg(target_os = "macos")]
    const OS: &str = "osx";
    #[cfg(not(target_os = "macos"))]
    const OS: &str = std::env::consts::OS;

    #[cfg(target_env = "musl")]
    const EXTRA: &str = "_musl";
    #[cfg(all(target_os = "windows", target_env = "gnu"))]
    const EXTRA: &str = "_mingw";
    #[cfg(not(any(target_env = "musl", all(target_os = "windows", target_env = "gnu"))))]
    const EXTRA: &str = "";

    /// Copies `parts` back to back into an `N`-byte array.
    ///
    /// Constant evaluation fails if the parts do not add up to exactly `N` bytes.
    const fn concat_bytes<const N: usize>(parts: &[&str]) -> [u8; N] {
        let mut out = [0u8; N];
        let mut pos = 0;
        let mut i = 0;
        while i < parts.len() {
            let bytes = parts[i].as_bytes();
            let mut j = 0;
            while j < bytes.len() {
                out[pos] = bytes[j];
                pos += 1;
                j += 1;
            }
            i += 1;
        }
        assert!(pos == N, "platform identifier length mismatch");
        out
    }

    // `concat!` only accepts literals, not `const` items (and ARCH/OS may come
    // from `std::env::consts`), so the pieces are joined by hand at compile time.
    const LEN: usize = OS.len() + 1 + ARCH.len() + EXTRA.len();
    const BYTES: [u8; LEN] = concat_bytes::<LEN>(&[OS, "_", ARCH, EXTRA]);
    const BYTES_REF: &[u8] = &BYTES;
    // SAFETY: BYTES is the byte-wise concatenation of `&str` values, and the
    // concatenation of valid UTF-8 strings is itself valid UTF-8.
    const TRIPLE: &str = unsafe { std::str::from_utf8_unchecked(BYTES_REF) };

    TRIPLE
}
+
+#[cfg(target_os = "windows")]
+extern crate windows_sys as windows;
+
#[cfg(target_os = "windows")]
mod target_windows {
    use std::ffi::c_void;
    use std::ffi::OsString;
    use std::os::windows::ffi::OsStringExt;
    use std::path::PathBuf;
    use std::slice;

    // `windows` is the `windows_sys` alias declared in the parent module; a
    // child module only sees it through an explicit import.
    use super::windows;
    use super::windows::Win32::UI::Shell;

    /// Returns the folder identified by `FOLDERID_LocalAppData` for the current
    /// user, or `None` if `SHGetKnownFolderPath` does not report success.
    ///
    /// `pub(super)` so the parent module's `user_config_dir` can call it.
    pub(super) fn user_config_dir() -> Option<PathBuf> {
        let mut path_ptr: windows::core::PWSTR = std::ptr::null_mut();

        // SAFETY: the folder id refers to a live GUID constant, a null token
        // selects the current user, and `path_ptr` is a valid out-pointer for the
        // duration of the call.
        let result = unsafe {
            Shell::SHGetKnownFolderPath(
                &Shell::FOLDERID_LocalAppData,
                0,
                std::ptr::null_mut(),
                &mut path_ptr,
            )
        };

        // 0 == S_OK
        let path = if result == 0 && !path_ptr.is_null() {
            // SAFETY: on success `path_ptr` is a NUL-terminated UTF-16 string that
            // stays allocated until the CoTaskMemFree call below; it is copied
            // into an owned OsString before that happens.
            let wide = unsafe {
                let len = windows::Win32::Globalization::lstrlenW(path_ptr) as usize;
                slice::from_raw_parts(path_ptr, len)
            };
            Some(PathBuf::from(OsString::from_wide(wide)))
        } else {
            None
        };

        // The returned buffer must be released with CoTaskMemFree whether or not
        // the call succeeded; passing null is permitted.
        // SAFETY: `path_ptr` is either null or was allocated by the shell call above.
        unsafe { windows::Win32::System::Com::CoTaskMemFree(path_ptr as *const c_void) };

        path
    }
}
+
/// Returns the per-user directory searched for ADBC driver manifests, or `None`
/// when it cannot be determined on this platform.
///
/// * Windows: `<FOLDERID_LocalAppData>\ADBC\drivers`
/// * macOS: `$HOME/Library/Application Support/ADBC`
/// * other Unix: `$XDG_CONFIG_HOME/adbc`, falling back to `$HOME/.config/adbc`
///   when `XDG_CONFIG_HOME` is unset, empty or not an absolute path
/// * any other target: `None`
///
/// The returned directory is not checked for existence here.
fn user_config_dir() -> Option<PathBuf> {
    #[cfg(target_os = "windows")]
    {
        use target_windows::user_config_dir;
        user_config_dir().map(|path| path.join("ADBC").join("drivers"))
    }

    #[cfg(target_os = "macos")]
    {
        env::var_os("HOME").map(|home| {
            PathBuf::from(home)
                .join("Library")
                .join("Application Support")
                .join("ADBC")
        })
    }

    #[cfg(all(unix, not(target_os = "macos")))]
    {
        // XDG Base Directory spec: an unset, empty or relative XDG_CONFIG_HOME is
        // ignored in favour of $HOME/.config.
        env::var_os("XDG_CONFIG_HOME")
            .map(PathBuf::from)
            .filter(|path| path.is_absolute())
            .or_else(|| env::var_os("HOME").map(|home| PathBuf::from(home).join(".config")))
            .map(|path| path.join("adbc"))
    }

    // Without this arm the function body would be empty (and ill-typed) on
    // targets that are neither Windows nor Unix.
    #[cfg(not(any(target_os = "windows", unix)))]
    {
        None
    }
}
+
+fn get_search_paths(lvls: LoadFlags) -> Vec<PathBuf> {
+    let mut result = Vec::new();
+    if lvls & LOAD_FLAG_SEARCH_ENV != 0 {
+        env::var_os("ADBC_CONFIG_PATH").and_then(|paths| {
+            for p in env::split_paths(&paths) {
+                result.push(p);
+            }
+            Some(())
+        });
+    }
+
+    if lvls & LOAD_FLAG_SEARCH_USER != 0 {
+        if let Some(path) = user_config_dir() {
+            if path.exists() {
+                result.push(path);
+            }
+        }
+    }
+
+    // system level for windows is to search the registry keys
+    #[cfg(not(windows))]
+    if lvls & LOAD_FLAG_SEARCH_SYSTEM != 0 {
+        let system_config_dir = PathBuf::from("/etc/adbc");
+        if system_config_dir.exists() {
+            result.push(system_config_dir);
+        }
+    }
+
+    result
+}
+
+#[cfg(test)]
+mod tests {
+    use super::*;
+
+    use crate::LOAD_FLAG_DEFAULT;
+    use tempfile::Builder;
+
+    fn simple_manifest() -> toml::Table {
+        // if this test is enabled, we expect the env var 
ADBC_DRIVER_MANAGER_TEST_LIB
+        // to be defined.
+        let driver_path =
+            PathBuf::from(env::var_os("ADBC_DRIVER_MANAGER_TEST_LIB").expect(
+                "ADBC_DRIVER_MANAGER_TEST_LIB must be set for driver manager 
manifest tests",
+            ));
+
+        assert!(
+            driver_path.exists(),
+            "ADBC_DRIVER_MANAGER_TEST_LIB path does not exist: {}",
+            driver_path.display()
+        );
+
+        let arch = current_arch();
+        format!(
+            r#"
+    name = 'SQLite3'
+    publisher = 'arrow-adbc'
+    version = '1.0.0'
+
+    [ADBC]
+    version = '1.1.0'
+
+    [Driver]
+    [Driver.shared]
+    {arch} = {driver_path:?}
+    "#
+        )
+        .parse::<toml::Table>()
+        .unwrap()
+    }
+
+    fn write_manifest_to_tempfile(p: PathBuf, tbl: toml::Table) -> 
(tempfile::TempDir, PathBuf) {
+        let tmp_dir = Builder::new()
+            .prefix("adbc_tests")
+            .tempdir()
+            .expect("Failed to create temporary directory for driver manager 
manifest test");
+
+        let manifest_path = tmp_dir.path().join(p);
+        if let Some(parent) = manifest_path.parent() {
+            std::fs::create_dir_all(parent)
+                .expect("Failed to create parent directory for manifest");
+        }
+
+        std::fs::write(&manifest_path, toml::to_string_pretty(&tbl).unwrap())
+            .expect("Failed to write driver manager manifest to temporary 
file");
+
+        (tmp_dir, manifest_path)
+    }
+
+    #[test]
+    #[cfg_attr(not(feature = "driver_manager_test_lib"), ignore)]
+    fn test_load_driver_env() {
+        // ensure that we fail without the env var set
+        assert!(ManagedDriver::find_load_from_name(
+            "sqlite",
+            None,
+            AdbcVersion::V100,
+            LOAD_FLAG_SEARCH_ENV
+        )
+        .is_err());

Review Comment:
   Originally I tried using that, but got an error that
`driver_manager::ManagedDriver` doesn't implement the `Debug` trait, so I guess I
should add a `#[derive(Debug)]` for it.



-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: github-unsubscr...@arrow.apache.org

For queries about this service, please contact Infrastructure at:
us...@infra.apache.org

Reply via email to