zeroshade commented on code in PR #3099: URL: https://github.com/apache/arrow-adbc/pull/3099#discussion_r2190476059
########## rust/core/src/driver_manager.rs: ########## @@ -1310,3 +1461,489 @@ impl Drop for ManagedStatement { unsafe { method(statement.deref_mut(), null_mut()) }; } } + +const fn current_arch() -> &'static str { + #[cfg(target_arch = "x86_64")] + const ARCH: &str = "amd64"; + #[cfg(target_arch = "aarch64")] + const ARCH: &str = "arm64"; + #[cfg(not(any(target_arch = "x86_64", target_arch = "aarch64")))] + const ARCH: &str = std::env::consts::ARCH; + + #[cfg(target_os = "macos")] + const OS: &str = "osx"; + #[cfg(not(target_os = "macos"))] + const OS: &str = std::env::consts::OS; + + #[cfg(target_env = "musl")] + const EXTRA: &str = "_musl"; + #[cfg(all(target_os = "windows", target_env = "gnu"))] + const EXTRA: &str = "_mingw"; + #[cfg(not(any(target_env = "musl", all(target_os = "windows", target_env = "gnu"))))] + const EXTRA: &str = ""; + + concat!(OS, "_", ARCH, EXTRA) +} + +#[cfg(target_os = "windows")] +extern crate windows_sys as windows; + +#[cfg(target_os = "windows")] +mod target_windows { + use std::ffi::c_void; + use std::ffi::OsString; + use std::os::windows::ffi::OsStringExt; + use std::path::PathBuf; + use std::slice; + + use super::windows::Win32::UI::Shell; + + fn user_config_dir() -> Option<PathBuf> { + unsafe { + let mut path_ptr: windows::core::PWSTR = std::ptr::null_mut(); + let result = Shell::SHGetKnownFolderPath( + Shell::FOLDERID_LocalAppData, + 0, + std::ptr::null_mut(), + &mut path_ptr, + ); + + if result == 0 { + let len = windows::Win32::Globalization::lstrlenW(path_ptr) as usize; + let path = slice::from_raw_parts(path_ptr, len); + let ostr: OsString = OsStringExt::from_wide(path); + windows::Win32::System::Com::CoTaskMemFree(path_ptr as *const c_void); + Some(PathBuf::from(ostr)) + } else { + windows::Win32::System::Com::CoTaskMemFree(path_ptr as *const c_void); + None + } + } + } +} + +fn user_config_dir() -> Option<PathBuf> { + #[cfg(target_os = "windows")] + { + use target_windows::user_config_dir; + 
user_config_dir().and_then(|path| { + path.push("ADBC"); + path.push("drivers"); + Some(path) + }) + } + + #[cfg(target_os = "macos")] + { + env::var_os("HOME").map(PathBuf::from).and_then(|mut path| { + path.push("Library"); + path.push("Application Support"); + path.push("ADBC"); + Some(path) + }) + } + + #[cfg(all(unix, not(target_os = "macos")))] + { + env::var_os("XDG_CONFIG_HOME") + .map(PathBuf::from) + .or_else(|| { + env::var_os("HOME").map(|home| { + let mut path = PathBuf::from(home); + path.push(".config"); + path + }) + }) + .and_then(|mut path| { + path.push("adbc"); + Some(path) + }) + } +} + +fn get_search_paths(lvls: LoadFlags) -> Vec<PathBuf> { + let mut result = Vec::new(); + if lvls & LOAD_FLAG_SEARCH_ENV != 0 { + env::var_os("ADBC_CONFIG_PATH").and_then(|paths| { + for p in env::split_paths(&paths) { + result.push(p); + } + Some(()) + }); + } + + if lvls & LOAD_FLAG_SEARCH_USER != 0 { + if let Some(path) = user_config_dir() { + if path.exists() { + result.push(path); + } + } + } + + // system level for windows is to search the registry keys + #[cfg(not(windows))] + if lvls & LOAD_FLAG_SEARCH_SYSTEM != 0 { + let system_config_dir = PathBuf::from("/etc/adbc"); + if system_config_dir.exists() { + result.push(system_config_dir); + } + } + + result +} + +#[cfg(test)] +mod tests { + use super::*; + + use crate::LOAD_FLAG_DEFAULT; + use tempfile::Builder; + + fn simple_manifest() -> toml::Table { + // if this test is enabled, we expect the env var ADBC_DRIVER_MANAGER_TEST_LIB + // to be defined. 
+ let driver_path = + PathBuf::from(env::var_os("ADBC_DRIVER_MANAGER_TEST_LIB").expect( + "ADBC_DRIVER_MANAGER_TEST_LIB must be set for driver manager manifest tests", + )); + + assert!( + driver_path.exists(), + "ADBC_DRIVER_MANAGER_TEST_LIB path does not exist: {}", + driver_path.display() + ); + + let arch = current_arch(); + format!( + r#" + name = 'SQLite3' + publisher = 'arrow-adbc' + version = '1.0.0' + + [ADBC] + version = '1.1.0' + + [Driver] + [Driver.shared] + {arch} = {driver_path:?} + "# + ) + .parse::<toml::Table>() + .unwrap() + } + + fn write_manifest_to_tempfile(p: PathBuf, tbl: toml::Table) -> (tempfile::TempDir, PathBuf) { + let tmp_dir = Builder::new() + .prefix("adbc_tests") + .tempdir() + .expect("Failed to create temporary directory for driver manager manifest test"); + + let manifest_path = tmp_dir.path().join(p); + if let Some(parent) = manifest_path.parent() { + std::fs::create_dir_all(parent) + .expect("Failed to create parent directory for manifest"); + } + + std::fs::write(&manifest_path, toml::to_string_pretty(&tbl).unwrap()) + .expect("Failed to write driver manager manifest to temporary file"); + + (tmp_dir, manifest_path) + } + + #[test] + #[cfg_attr(not(feature = "driver_manager_test_lib"), ignore)] + fn test_load_driver_env() { + // ensure that we fail without the env var set + assert!(ManagedDriver::find_load_from_name( + "sqlite", + None, + AdbcVersion::V100, + LOAD_FLAG_SEARCH_ENV + ) + .is_err()); Review Comment: Originally I tried using that, but got an error that `driver_manager::ManagedDriver` doesn't implement the `Debug` trait. So I guess I should add the derive for that. -- This is an automated message from the Apache Git Service. To respond to the message, please log on to GitHub and use the URL above to go to the specific comment. To unsubscribe, e-mail: github-unsubscr...@arrow.apache.org For queries about this service, please contact Infrastructure at: us...@infra.apache.org