pwrliang commented on code in PR #586: URL: https://github.com/apache/sedona-db/pull/586#discussion_r2805824007
########## c/sedona-libgpuspatial/src/libgpuspatial.rs: ########## @@ -0,0 +1,583 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +use crate::error::GpuSpatialError; +#[cfg(gpu_available)] +use crate::libgpuspatial_glue_bindgen::*; +use crate::predicate::GpuSpatialRelationPredicate; +use arrow_array::{Array, ArrayRef}; +use arrow_schema::ffi::FFI_ArrowSchema; +use arrow_schema::DataType; +use std::convert::TryFrom; +use std::ffi::{CStr, CString}; +use std::os::raw::c_char; +use std::sync::{Arc, Mutex}; + +// ---------------------------------------------------------------------- +// Runtime Wrapper +// ---------------------------------------------------------------------- + +pub struct GpuSpatialRuntimeWrapper { + runtime: GpuSpatialRuntime, +} + +impl GpuSpatialRuntimeWrapper { + pub fn try_new( + device_id: i32, + ptx_root: &str, + use_cuda_memory_pool: bool, + cuda_memory_pool_init_precent: i32, + ) -> Result<GpuSpatialRuntimeWrapper, GpuSpatialError> { + let mut runtime = GpuSpatialRuntime { + init: None, + release: None, + get_last_error: None, + private_data: std::ptr::null_mut(), + }; + + unsafe { + GpuSpatialRuntimeCreate(&mut runtime); + } + + if let Some(init_fn) = runtime.init { + let c_ptx_root = 
CString::new(ptx_root).map_err(|_| { + GpuSpatialError::Init("Failed to convert ptx_root to CString".into()) + })?; + + let mut config = GpuSpatialRuntimeConfig { + device_id, + ptx_root: c_ptx_root.as_ptr(), + use_cuda_memory_pool, + cuda_memory_pool_init_precent, + }; + + unsafe { + let get_last_error = runtime.get_last_error; + let runtime_ptr = &mut runtime as *mut GpuSpatialRuntime; + + check_ffi_call( + move || init_fn(runtime_ptr as *mut _, &mut config), + get_last_error, + runtime_ptr, + GpuSpatialError::Init, + )?; + } + } + + Ok(GpuSpatialRuntimeWrapper { runtime }) + } +} + +impl Drop for GpuSpatialRuntimeWrapper { + fn drop(&mut self) { + if let Some(release_fn) = self.runtime.release { + unsafe { + release_fn(&mut self.runtime as *mut _); + } + } + } +} + +// ---------------------------------------------------------------------- +// Spatial Index - Internal Wrapper +// ---------------------------------------------------------------------- + +/// Internal wrapper that manages the lifecycle of the C `SedonaFloatIndex2D` struct. +/// It is wrapped in an `Arc` by the public structs to ensure thread safety. +struct FloatIndex2DWrapper { + index: SedonaFloatIndex2D, + // Keep a reference to the RT engine to ensure it lives as long as the index + _runtime: Arc<Mutex<GpuSpatialRuntimeWrapper>>, +} + +// The C library is designed for thread safety when used correctly (separate contexts per thread) +unsafe impl Send for FloatIndex2DWrapper {} +unsafe impl Sync for FloatIndex2DWrapper {} + +impl Drop for FloatIndex2DWrapper { + fn drop(&mut self) { + if let Some(release_fn) = self.index.release { + unsafe { + release_fn(&mut self.index as *mut _); + } + } + } +} + +// ---------------------------------------------------------------------- +// Spatial Index - Builder +// ---------------------------------------------------------------------- + +/// Builder for the Spatial Index. 
This struct has exclusive ownership +/// and is not thread-safe (Send but not Sync) because building is a +/// single-threaded operation. +pub struct FloatIndex2DBuilder { + inner: FloatIndex2DWrapper, +} + +impl FloatIndex2DBuilder { + pub fn try_new( + runtime: Arc<Mutex<GpuSpatialRuntimeWrapper>>, + concurrency: u32, + ) -> Result<Self, GpuSpatialError> { + let mut index = SedonaFloatIndex2D { + clear: None, + create_context: None, + destroy_context: None, + push_build: None, + finish_building: None, + probe: None, + get_build_indices_buffer: None, + get_probe_indices_buffer: None, + get_last_error: None, + context_get_last_error: None, + release: None, + private_data: std::ptr::null_mut(), + }; + + let mut engine_guard = runtime + .lock() + .map_err(|_| GpuSpatialError::Init("Failed to acquire mutex lock".to_string()))?; + + let config = GpuSpatialIndexConfig { + runtime: &mut engine_guard.runtime, + concurrency, + }; + + unsafe { + if GpuSpatialIndexFloat2DCreate(&mut index, &config) != 0 { + let msg = if let Some(get_err) = index.get_last_error { + CStr::from_ptr(get_err(&index as *const _ as *mut _)) + .to_string_lossy() + .into_owned() + } else { + "Unknown error during Index Create".into() + }; + return Err(GpuSpatialError::Init(msg)); + } + } + + Ok(FloatIndex2DBuilder { + inner: FloatIndex2DWrapper { + index, + _runtime: runtime.clone(), + }, + }) + } + + pub fn clear(&mut self) { + if let Some(clear_fn) = self.inner.index.clear { + unsafe { + clear_fn(&mut self.inner.index as *mut _); + } + } + } + + pub unsafe fn push_build( + &mut self, + buf: *const f32, + n_rects: u32, + ) -> Result<(), GpuSpatialError> { + if let Some(push_build_fn) = self.inner.index.push_build { + let get_last_error = self.inner.index.get_last_error; + let index_ptr = &mut self.inner.index as *mut _; + + check_ffi_call( + move || push_build_fn(index_ptr, buf, n_rects), + get_last_error, + index_ptr, + GpuSpatialError::PushBuild, + )?; + } + Ok(()) + } + + /// Consumes the builder 
and returns a shared, thread-safe index wrapper. + pub fn finish(mut self) -> Result<SharedFloatIndex2D, GpuSpatialError> { + if let Some(finish_building_fn) = self.inner.index.finish_building { + // Extract to local vars + let get_last_error = self.inner.index.get_last_error; + let index_ptr = &mut self.inner.index as *mut _; + + unsafe { + check_ffi_call( + move || finish_building_fn(index_ptr), + get_last_error, + index_ptr, + GpuSpatialError::FinishBuild, + )?; + } + } + + Ok(SharedFloatIndex2D { + inner: Arc::new(self.inner), + }) + } +} + +// ---------------------------------------------------------------------- +// Spatial Index - Shared Read-Only Index +// ---------------------------------------------------------------------- + +/// Thread-safe wrapper around the built index. +/// Used to spawn thread-local contexts for probing. +#[derive(Clone)] +pub struct SharedFloatIndex2D { + inner: Arc<FloatIndex2DWrapper>, +} + +unsafe impl Send for SharedFloatIndex2D {} +unsafe impl Sync for SharedFloatIndex2D {} + +impl SharedFloatIndex2D { + pub fn create_context(&self) -> Result<FloatIndex2DContext, GpuSpatialError> { + let mut ctx = SedonaSpatialIndexContext { + private_data: std::ptr::null_mut(), + }; + + if let Some(create_context_fn) = self.inner.index.create_context { + unsafe { + create_context_fn(&mut ctx); + } + } + + Ok(FloatIndex2DContext { + inner: self.inner.clone(), + context: ctx, + }) + } + /// Probes the index using the provided thread-local context. + /// The context is modified to contain the result buffers. 
+ pub unsafe fn probe( + &self, + ctx: &mut FloatIndex2DContext, + buf: *const f32, + n_rects: u32, + ) -> Result<(), GpuSpatialError> { + if let Some(probe_fn) = self.inner.index.probe { + // Get mutable pointer to the index (C API requirement, safe due to internal locking/context usage) + let index_ptr = &self.inner.index as *const _ as *mut SedonaFloatIndex2D; + + // Pass the context from the argument + if probe_fn(index_ptr, &mut ctx.context, buf, n_rects) != 0 { + let error_string = + if let Some(get_ctx_err) = self.inner.index.context_get_last_error { + CStr::from_ptr(get_ctx_err(&mut ctx.context)) + .to_string_lossy() + .into_owned() + } else { + "Unknown context error".to_string() + }; + return Err(GpuSpatialError::Probe(error_string)); + } + } + Ok(()) + } +} + +/// Thread-local context for probing the index. +/// This struct is Send (can be moved between threads) but NOT Sync. +pub struct FloatIndex2DContext { + inner: Arc<FloatIndex2DWrapper>, // Shared reference to the index wrapper to ensure it lives as long as the context + context: SedonaSpatialIndexContext, // The actual C context struct that holds thread-local state and result buffers +} + +unsafe impl Send for FloatIndex2DContext {} + +impl FloatIndex2DContext { + fn get_indices_buffer_helper( + &mut self, + func: Option<unsafe extern "C" fn(*mut SedonaSpatialIndexContext, *mut *mut u32, *mut u32)>, + ) -> &[u32] { + if let Some(f) = func { + let mut ptr: *mut u32 = std::ptr::null_mut(); + let mut len: u32 = 0; + unsafe { + f(&mut self.context, &mut ptr, &mut len); + if len > 0 && !ptr.is_null() { + return std::slice::from_raw_parts(ptr, len as usize); + } + } + } + &[] + } + + pub fn get_build_indices_buffer(&mut self) -> &[u32] { + self.get_indices_buffer_helper(self.inner.index.get_build_indices_buffer) + } + + pub fn get_probe_indices_buffer(&mut self) -> &[u32] { + self.get_indices_buffer_helper(self.inner.index.get_probe_indices_buffer) + } +} + +impl Drop for FloatIndex2DContext { + fn 
drop(&mut self) { +        if let Some(destroy_context_fn) = self.inner.index.destroy_context { +            unsafe { +                destroy_context_fn(&mut self.context); +            } +        } +    } +} + +struct RefinerWrapper { +    refiner: SedonaSpatialRefiner, +    _runtime: Arc<Mutex<GpuSpatialRuntimeWrapper>>, +} + +unsafe impl Send for RefinerWrapper {} +unsafe impl Sync for RefinerWrapper {} + +impl Drop for RefinerWrapper { +    fn drop(&mut self) { +        if let Some(release_fn) = self.refiner.release { +            unsafe { +                release_fn(&mut self.refiner as *mut _); +            } +        } +    } +} + +pub struct GpuSpatialRefinerBuilder { +    inner: RefinerWrapper, +} Review Comment: The refiner can be used in parallel, just like the index. The geometries from the build side are pushed to the refiner and will be shared by all partitions on the stream side. -- This is an automated message from the Apache Git Service. To respond to the message, please log on to GitHub and use the URL above to go to the specific comment. To unsubscribe, e-mail: [email protected] For queries about this service, please contact Infrastructure at: [email protected]
