This is an automated email from the ASF dual-hosted git repository.
zeroshade pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/arrow-experiments.git
The following commit(s) were added to refs/heads/main by this push:
new 05e4e88 add cudf-flight-ucx example (#28)
05e4e88 is described below
commit 05e4e888b19dbd98b95d8984a8d1f97fb6570d00
Author: Matt Topol <[email protected]>
AuthorDate: Thu Apr 25 13:25:09 2024 -0400
add cudf-flight-ucx example (#28)
* add cudf-flight-ucx example
* Apply suggestions from code review
Co-authored-by: Sutou Kouhei <[email protected]>
* Update dissociated-ipc/cudf-flight-poc.cc
Co-authored-by: Sutou Kouhei <[email protected]>
* ran linting
* Apply suggestions from code review
* split poc file for readability
* Update dissociated-ipc/README.md
Co-authored-by: Sutou Kouhei <[email protected]>
* rename files
---------
Co-authored-by: Sutou Kouhei <[email protected]>
Co-authored-by: Ian Cook <[email protected]>
---
.clang-format | 21 ++
.gitignore | 21 ++
data/taxi-data/README.md | 22 ++
data/taxi-data/train.parquet | 3 +
dissociated-ipc/CMakeLists.txt | 112 ++++++++++
dissociated-ipc/README.md | 55 +++++
dissociated-ipc/cudf-flight-client.cc | 384 ++++++++++++++++++++++++++++++++
dissociated-ipc/cudf-flight-server.cc | 408 ++++++++++++++++++++++++++++++++++
dissociated-ipc/cudf-flight-ucx.cc | 39 ++++
dissociated-ipc/cudf-flight-ucx.h | 38 ++++
dissociated-ipc/ucx_client.cc | 73 ++++++
dissociated-ipc/ucx_client.h | 40 ++++
dissociated-ipc/ucx_conn.cc | 355 +++++++++++++++++++++++++++++
dissociated-ipc/ucx_conn.h | 90 ++++++++
dissociated-ipc/ucx_server.cc | 280 +++++++++++++++++++++++
dissociated-ipc/ucx_server.h | 88 ++++++++
dissociated-ipc/ucx_utils.cc | 287 ++++++++++++++++++++++++
dissociated-ipc/ucx_utils.h | 122 ++++++++++
18 files changed, 2438 insertions(+)
diff --git a/.clang-format b/.clang-format
new file mode 100644
index 0000000..9448dc8
--- /dev/null
+++ b/.clang-format
@@ -0,0 +1,21 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+---
+BasedOnStyle: Google
+ColumnLimit: 90
+DerivePointerAlignment: false
+IncludeBlocks: Preserve
diff --git a/.gitignore b/.gitignore
new file mode 100644
index 0000000..d997483
--- /dev/null
+++ b/.gitignore
@@ -0,0 +1,21 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+vendored
+build
+.vscode
+cufile.log
diff --git a/data/taxi-data/README.md b/data/taxi-data/README.md
new file mode 100644
index 0000000..6a7416e
--- /dev/null
+++ b/data/taxi-data/README.md
@@ -0,0 +1,22 @@
+<!---
+ Licensed to the Apache Software Foundation (ASF) under one
+ or more contributor license agreements. See the NOTICE file
+ distributed with this work for additional information
+ regarding copyright ownership. The ASF licenses this file
+ to you under the Apache License, Version 2.0 (the
+ "License"); you may not use this file except in compliance
+ with the License. You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+ Unless required by applicable law or agreed to in writing,
+ software distributed under the License is distributed on an
+ "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ KIND, either express or implied. See the License for the
+ specific language governing permissions and limitations
+ under the License.
+-->
+
+# taxi-data
+
+A small subset of the public [NYC Taxi
+Data](https://www.nyc.gov/site/tlc/about/tlc-trip-record-data.page) used in the
+dissociated-ipc example. `train.parquet` is stored with Git LFS, so run
+`git lfs pull` to fetch its contents.
diff --git a/data/taxi-data/train.parquet b/data/taxi-data/train.parquet
new file mode 100755
index 0000000..7bf702b
--- /dev/null
+++ b/data/taxi-data/train.parquet
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:854cf53ab8669aa260a8ae65beafe880ab1a0232dbdac09705fb9b6f3f84eacd
+size 38521857
diff --git a/dissociated-ipc/CMakeLists.txt b/dissociated-ipc/CMakeLists.txt
new file mode 100644
index 0000000..fa46397
--- /dev/null
+++ b/dissociated-ipc/CMakeLists.txt
@@ -0,0 +1,112 @@
+# Licensed to the Apache Software Foundation (ASF) under one or more
contributor
+# license agreements. See the NOTICE file distributed with this work for
+# additional information regarding copyright ownership. The ASF licenses this
+# file to you under the Apache License, Version 2.0 (the "License"); you may
not
+# use this file except in compliance with the License. You may obtain a copy
of
+# the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
+# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
+# License for the specific language governing permissions and limitations under
+# the License.
+
cmake_minimum_required(VERSION 3.20)
message(STATUS "Building using CMake version: ${CMAKE_VERSION}")
project(arrow-cudf-flight CXX CUDA)

include(CMakeParseArguments)

# https://www.cmake.org/cmake/help/latest/policy/CMP0025.html
#
# Compiler id for Apple Clang is now AppleClang.
cmake_policy(SET CMP0025 NEW)

# https://cmake.org/cmake/help/latest/policy/CMP0042.html
#
# Enable MACOSX_RPATH by default. @rpath in a target's install name is a more
# flexible and powerful mechanism than @executable_path or @loader_path for
# locating shared libraries.
cmake_policy(SET CMP0042 NEW)

# https://www.cmake.org/cmake/help/latest/policy/CMP0054.html
#
# Only interpret if() arguments as variables or keywords when unquoted.
cmake_policy(SET CMP0054 NEW)

# https://www.cmake.org/cmake/help/latest/policy/CMP0057.html
#
# Support new if() IN_LIST operator.
cmake_policy(SET CMP0057 NEW)

# https://www.cmake.org/cmake/help/latest/policy/CMP0063.html
#
# Adapted from Apache Kudu:
# https://github.com/apache/kudu/commit/bd549e13743a51013585 Honor visibility
# properties for all target types.
cmake_policy(SET CMP0063 NEW)

# https://cmake.org/cmake/help/latest/policy/CMP0068.html
#
# RPATH settings on macOS do not affect install_name.
cmake_policy(SET CMP0068 NEW)

# https://cmake.org/cmake/help/latest/policy/CMP0074.html
#
# find_package() uses <PackageName>_ROOT variables.
cmake_policy(SET CMP0074 NEW)

# https://cmake.org/cmake/help/latest/policy/CMP0091.html
#
# MSVC runtime library flags are selected by an abstraction.
cmake_policy(SET CMP0091 NEW)

# https://cmake.org/cmake/help/latest/policy/CMP0135.html
#
# CMP0135 is for solving re-building and re-downloading. We don't have a real
# problem with the OLD behavior for now but we use the NEW behavior explicitly
# to suppress CMP0135 warnings.
if(POLICY CMP0135)
  cmake_policy(SET CMP0135 NEW)
endif()

find_package(ArrowFlight REQUIRED)
find_package(ArrowCUDA REQUIRED)
message(STATUS "Found Arrow: ${ARROW_VERSION}")

# FindCUDA (find_package(CUDA)) is deprecated and removed under policy
# CMP0146; FindCUDAToolkit provides imported targets instead. The sources call
# the CUDA driver API directly (cuCtxPushCurrent, cuMemsetD8), so the driver
# library is linked explicitly below via CUDA::cuda_driver.
find_package(CUDAToolkit REQUIRED)
find_package(gflags REQUIRED)
find_package(cudf REQUIRED)
find_package(ucx REQUIRED)

# aggregate UCX target; only create it if the ucx package did not define it
if(NOT TARGET ucx::ucx)
  add_library(ucx::ucx INTERFACE IMPORTED)
  target_link_libraries(ucx::ucx INTERFACE ucx::ucp ucx::uct ucx::ucs)
endif()

add_compile_definitions(FMT_USE_NONTYPE_TEMPLATE_ARGS=0)
add_executable(
  arrow-cudf-flight
  cudf-flight-ucx.cc
  cudf-flight-client.cc
  cudf-flight-server.cc
  ucx_utils.cc
  ucx_server.cc
  ucx_client.cc
  ucx_conn.cc)
target_link_libraries(
  arrow-cudf-flight
  arrow_shared
  arrow_cuda_shared
  arrow_flight_shared
  gflags
  cudf::cudf
  CUDA::cuda_driver
  ucx::ucx)
set_target_properties(
  arrow-cudf-flight
  PROPERTIES BUILD_RPATH "\$ORIGIN"
             INSTALL_RPATH "\$ORIGIN"
             CXX_STANDARD 20
             CXX_STANDARD_REQUIRED ON
             CXX_EXTENSIONS ON
             CUDA_STANDARD 20
             CUDA_STANDARD_REQUIRED ON)
diff --git a/dissociated-ipc/README.md b/dissociated-ipc/README.md
new file mode 100644
index 0000000..b84624d
--- /dev/null
+++ b/dissociated-ipc/README.md
@@ -0,0 +1,55 @@
+<!---
+ Licensed to the Apache Software Foundation (ASF) under one
+ or more contributor license agreements. See the NOTICE file
+ distributed with this work for additional information
+ regarding copyright ownership. The ASF licenses this file
+ to you under the Apache License, Version 2.0 (the
+ "License"); you may not use this file except in compliance
+ with the License. You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+ Unless required by applicable law or agreed to in writing,
+ software distributed under the License is distributed on an
+ "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ KIND, either express or implied. See the License for the
+ specific language governing permissions and limitations
+ under the License.
+-->
+
+# Arrow Dissociated IPC Protocol Example
+
+This directory contains a reference example implementation of the
+[Arrow Dissociated IPC
+Protocol](https://arrow.apache.org/docs/dev/format/DissociatedIPC.html).
+
+This protocol splits the Arrow Flatbuffers IPC metadata and the body buffers
+into separate streams to allow for utilizing shared memory, non-CPU device
+memory, or remote memory (RDMA) with Arrow formatted datasets.
+
+This example utilizes [libcudf](https://docs.rapids.ai/api) and
+[UCX](https://openucx.readthedocs.io/en/master/#) to transfer Arrow data
+located on an NVIDIA GPU.
+
+## Building
+
+You must have libcudf, libarrow, libarrow_flight, libarrow_cuda, and UCX
+accessible on your `CMAKE_MODULE_PATH`/`CMAKE_PREFIX_PATH` so that `cmake` can
+find them.
+
+After that you can simply do the following:
+
+```console
+$ cmake -S . -B build -DCMAKE_BUILD_TYPE=Release
+$ cmake --build build
+```
+
+to build the `arrow-cudf-flight` executable.
+
+## Running
+
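+You can start the server by just running `arrow-cudf-flight`, which will
+default to using `31337` as the Flight port and `127.0.0.1` for the host.
+Both of these can be changed via the `-port` and `-address` gflags
+respectively. Run it from the repository root: the server loads
+`./data/taxi-data/train.parquet` relative to the current working directory.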
+
+You can run the client by adding the `-client` option when running the
+command.
diff --git a/dissociated-ipc/cudf-flight-client.cc
b/dissociated-ipc/cudf-flight-client.cc
new file mode 100644
index 0000000..d95525b
--- /dev/null
+++ b/dissociated-ipc/cudf-flight-client.cc
@@ -0,0 +1,384 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
#include <atomic>
#include <charconv>
#include <chrono>
#include <condition_variable>
#include <future>
#include <iostream>
#include <memory>
#include <mutex>
#include <queue>
#include <string>
#include <system_error>
#include <thread>
#include <unordered_map>
#include <utility>

#include <arrow/array.h>
#include <arrow/flight/client.h>
#include <arrow/gpu/cuda_api.h>
#include <arrow/ipc/api.h>
#include <arrow/util/endian.h>
#include <arrow/util/logging.h>
#include <arrow/util/uri.h>

#include "cudf-flight-ucx.h"
#include "ucx_client.h"
+
+namespace flight = arrow::flight;
+namespace ipc = arrow::ipc;
+
+arrow::Result<ucp_tag_t> get_want_data_tag(const arrow::util::Uri& loc) {
+ ARROW_ASSIGN_OR_RAISE(auto query_params, loc.query_items());
+ for (auto& q : query_params) {
+ if (q.first == "want_data") {
+ return std::stoull(q.second);
+ }
+ }
+ return 0;
+}
+
+// utility client class to read a stream of data using the dissociated ipc
+// protocol structure
+class StreamReader {
+ public:
+ StreamReader(utils::Connection* ctrl_cnxn, utils::Connection* data_cnxn)
+ : ctrl_cnxn_{ctrl_cnxn}, data_cnxn_{data_cnxn} {
+ ARROW_UNUSED(ctrl_cnxn_->SetAMHandler(0, this, RecvMsg));
+ }
+
+ void set_data_mem_manager(std::shared_ptr<arrow::MemoryManager> mgr) {
+ if (!mgr) {
+ mm_ = arrow::CPUDevice::Instance()->default_memory_manager();
+ } else {
+ mm_ = std::move(mgr);
+ }
+ }
+
+ arrow::Status Start(ucp_tag_t ctrl_tag, ucp_tag_t data_tag, const
std::string& ident) {
+ // consume the data and metadata streams simultaneously
+ ARROW_RETURN_NOT_OK(ctrl_cnxn_->SendTagSync(ctrl_tag, ident.data(),
ident.size()));
+ ARROW_RETURN_NOT_OK(data_cnxn_->SendTagSync(data_tag, ident.data(),
ident.size()));
+
+ std::thread(&StreamReader::run_data_loop, this).detach();
+ std::thread(&StreamReader::run_meta_loop, this).detach();
+
+ return arrow::Status::OK();
+ }
+
+ arrow::Result<std::shared_ptr<arrow::Schema>> Schema() {
+ // return the schema if we've already pulled it
+ if (schema_) {
+ return schema_;
+ }
+
+ // otherwise the next message should be the schema
+ ARROW_ASSIGN_OR_RAISE(auto msg, NextMsg());
+ ARROW_ASSIGN_OR_RAISE(schema_, ipc::ReadSchema(*msg, &dictionary_memo_));
+ return schema_;
+ }
+
+ arrow::Result<std::shared_ptr<arrow::RecordBatch>> Next() {
+ // we need the schema to read the record batch, also ensuring that
+ // we will retrieve the schema message which should be the first message
+ ARROW_ASSIGN_OR_RAISE(auto schema, Schema());
+ ARROW_ASSIGN_OR_RAISE(auto msg, NextMsg());
+ if (msg) {
+ return ipc::ReadRecordBatch(*msg, schema, &dictionary_memo_,
ipc_options_);
+ }
+ // we've hit the end
+ return nullptr;
+ }
+
+ protected:
+ struct PendingMsg {
+ std::promise<std::unique_ptr<ipc::Message>> p;
+ std::shared_ptr<arrow::Buffer> metadata;
+ std::shared_ptr<arrow::Buffer> body;
+ StreamReader* rdr;
+ };
+
+ // data stream loop handler
+ void run_data_loop() {
+ if (arrow::cuda::IsCudaMemoryManager(*mm_)) {
+ // since we're in a new thread, we need to make sure to push the cuda
context
+ // so that ucx uses the same cuda context as the Arrow data is using,
otherwise
+ // the device pointers aren't valid
+ auto ctx =
*(*arrow::cuda::AsCudaMemoryManager(mm_))->cuda_device()->GetContext();
+ cuCtxPushCurrent(reinterpret_cast<CUcontext>(ctx->handle()));
+ }
+
+ while (true) {
+ // progress the connection until an event happens
+ while (data_cnxn_->Progress()) {
+ }
+ {
+ // check if we have received any metadata which indicate we need to
poll
+ // for a corresponding tagged data message
+ std::unique_lock<std::mutex> guard(polling_mutex_);
+ for (auto it = polling_map_.begin(); it != polling_map_.end();) {
+ auto maybe_tag =
+ data_cnxn_->ProbeForTag(ucp_tag_t(it->first),
0x00000000FFFFFFFF, 1);
+ if (!maybe_tag.ok()) {
+ ARROW_LOG(ERROR) << maybe_tag.status().ToString();
+ return;
+ }
+
+ auto tag_pair = maybe_tag.MoveValueUnsafe();
+ if (tag_pair.second != nullptr) {
+ // got one!
+ auto st = RecvTag(tag_pair.second, tag_pair.first,
std::move(it->second));
+ if (!st.ok()) {
+ ARROW_LOG(ERROR) << st.ToString();
+ return;
+ }
+ it = polling_map_.erase(it);
+ } else {
+ ++it;
+ }
+ }
+ }
+
+ // if the metadata stream has ended...
+ if (finished_metadata_.load()) {
+ // we are done if there's nothing left to poll for and nothing
outstanding
+ std::lock_guard<std::mutex> guard(polling_mutex_);
+ if (polling_map_.empty() && outstanding_tags_.load() == 0) {
+ break;
+ }
+ }
+ }
+ }
+
+ // a mask to grab the byte indicating the body message type.
+ static constexpr uint64_t kbody_mask_ = 0x0100000000000000;
+
+ arrow::Status RecvTag(ucp_tag_message_h msg, ucp_tag_recv_info_t info_tag,
+ PendingMsg pending) {
+ ++outstanding_tags_;
+ ARROW_ASSIGN_OR_RAISE(auto buf, mm_->AllocateBuffer(info_tag.length));
+
+ PendingMsg* new_pending = new PendingMsg(std::move(pending));
+ new_pending->body = std::move(buf);
+ new_pending->rdr = this;
+ return data_cnxn_->RecvTagData(
+ msg, reinterpret_cast<void*>(new_pending->body->address()),
info_tag.length,
+ new_pending,
+ [](void* request, ucs_status_t status, const ucp_tag_recv_info_t*
tag_info,
+ void* user_data) {
+ auto pending =
+
std::unique_ptr<PendingMsg>(reinterpret_cast<PendingMsg*>(user_data));
+ if (status != UCS_OK) {
+ ARROW_LOG(ERROR)
+ << utils::FromUcsStatus("ucp_tag_recv_nbx_callback",
status).ToString();
+ pending->p.set_value(nullptr);
+ return;
+ }
+
+ if (request) ucp_request_free(request);
+
+ if (tag_info->sender_tag & kbody_mask_) {
+ // pointer / offset list body
+ // not yet implemented
+ } else {
+ // full body bytes, use the pending metadata and read our IPC
message
+ // as usual
+ auto msg = *ipc::Message::Open(pending->metadata, pending->body);
+ pending->p.set_value(std::move(msg));
+ --pending->rdr->outstanding_tags_;
+ }
+ },
+ (new_pending->body->is_cpu()) ? UCS_MEMORY_TYPE_HOST :
UCS_MEMORY_TYPE_CUDA);
+ }
+
+ // handle the metadata stream
+ void run_meta_loop() {
+ while (!finished_metadata_.load()) {
+ // progress the connection until we get an event
+ while (ctrl_cnxn_->Progress()) {
+ }
+ {
+ std::unique_lock<std::mutex> guard(queue_mutex_);
+ while (!metadata_queue_.empty()) {
+ // handle any metadata messages in our queue
+ auto buf = std::move(metadata_queue_.front());
+ metadata_queue_.pop();
+ guard.unlock();
+
+ while (buf.wait_for(std::chrono::seconds(0)) !=
std::future_status::ready) {
+ ctrl_cnxn_->Progress();
+ }
+
+ std::shared_ptr<arrow::Buffer> buffer = buf.get();
+ if (static_cast<MetadataMsgType>(buffer->data()[0]) ==
MetadataMsgType::EOS) {
+ finished_metadata_.store(true);
+ guard.lock();
+ continue;
+ }
+
+ uint32_t sequence_number = utils::BytesToUint32LE(buffer->data() +
1);
+ auto metadata = SliceBuffer(buffer, 5, buffer->size() - 5);
+
+ // store a mapping of sequence numbers to std::future that returns
the data
+ std::promise<std::unique_ptr<ipc::Message>> p;
+ {
+ std::lock_guard<std::mutex> lock(msg_mutex_);
+ msg_map_.insert({sequence_number, p.get_future()});
+ }
+ cv_progress_.notify_all();
+
+ auto msg = ipc::Message::Open(metadata, nullptr).ValueOrDie();
+ if (!ipc::Message::HasBody(msg->type())) {
+ p.set_value(std::move(msg));
+ guard.lock();
+ continue;
+ }
+
+ {
+ std::lock_guard<std::mutex> lock(polling_mutex_);
+ polling_map_.insert(
+ {sequence_number, PendingMsg{std::move(p),
std::move(metadata)}});
+ }
+
+ guard.lock();
+ }
+ }
+
+ if (finished_metadata_.load()) break;
+ auto status = utils::FromUcsStatus("ucp_worker_wait",
ctrl_cnxn_->WorkerWait());
+ if (!status.ok()) {
+ ARROW_LOG(ERROR) << status.ToString();
+ return;
+ }
+ }
+ }
+
+ arrow::Result<std::unique_ptr<ipc::Message>> NextMsg() {
+ // fetch the next IPC message by sequence number
+ const uint32_t counter = next_counter_++;
+ std::future<std::unique_ptr<ipc::Message>> futr;
+ {
+ std::unique_lock<std::mutex> lock(msg_mutex_);
+ if (msg_map_.empty() && finished_metadata_.load() &&
!outstanding_tags_.load()) {
+ return nullptr;
+ }
+
+ auto it = msg_map_.find(counter);
+ if (it == msg_map_.end()) {
+ // wait until we get a message for this sequence number
+ cv_progress_.wait(lock, [this, counter, &it] {
+ it = msg_map_.find(counter);
+ return it != msg_map_.end() || finished_metadata_.load();
+ });
+ }
+ futr = std::move(it->second);
+ msg_map_.erase(it);
+ }
+
+ // .get on a future will block until it either recieves a value or fails
+ return futr.get();
+ }
+
+ // callback function to recieve untagged "Active Messages"
+ static ucs_status_t RecvMsg(void* arg, const void* header, size_t header_len,
+ void* data, size_t length,
+ const ucp_am_recv_param_t* param) {
+ StreamReader* rdr = reinterpret_cast<StreamReader*>(arg);
+ DCHECK(length);
+
+ std::promise<std::unique_ptr<arrow::Buffer>> p;
+ {
+ std::lock_guard<std::mutex> lock(rdr->queue_mutex_);
+ rdr->metadata_queue_.push(p.get_future());
+ }
+
+ return rdr->ctrl_cnxn_->RecvAM(std::move(p), header, header_len, data,
length, param);
+ }
+
+ private:
+ utils::Connection* ctrl_cnxn_;
+ utils::Connection* data_cnxn_;
+ std::shared_ptr<arrow::Schema> schema_;
+ ipc::DictionaryMemo dictionary_memo_;
+ ipc::IpcReadOptions ipc_options_;
+
+ std::shared_ptr<arrow::MemoryManager> mm_;
+ std::atomic<bool> finished_metadata_{false};
+ std::atomic<uint32_t> outstanding_tags_{0};
+ uint32_t next_counter_{0};
+
+ std::condition_variable cv_progress_;
+ std::mutex queue_mutex_;
+ std::queue<std::future<std::unique_ptr<arrow::Buffer>>> metadata_queue_;
+ std::mutex polling_mutex_;
+ std::unordered_map<uint32_t, PendingMsg> polling_map_;
+ std::mutex msg_mutex_;
+ std::unordered_map<uint32_t, std::future<std::unique_ptr<ipc::Message>>>
msg_map_;
+};
+
+arrow::Status run_client(const std::string& addr, const int port) {
+ ARROW_ASSIGN_OR_RAISE(auto location, flight::Location::ForGrpcTcp(addr,
port));
+ ARROW_ASSIGN_OR_RAISE(auto client, flight::FlightClient::Connect(location));
+
+ ARROW_ASSIGN_OR_RAISE(
+ auto info,
+
client->GetFlightInfo(flight::FlightDescriptor::Command("train.parquet")));
+ ARROW_LOG(DEBUG) << info->endpoints()[0].locations[0].ToString();
+ ARROW_LOG(DEBUG) << info->endpoints()[0].locations[1].ToString();
+
+ ARROW_ASSIGN_OR_RAISE(auto ctrl_uri, arrow::util::Uri::FromString(
+
info->endpoints()[0].locations[0].ToString()));
+ ARROW_ASSIGN_OR_RAISE(auto data_uri, arrow::util::Uri::FromString(
+
info->endpoints()[0].locations[1].ToString()));
+
+ ARROW_ASSIGN_OR_RAISE(ucp_tag_t ctrl_tag, get_want_data_tag(ctrl_uri));
+ ARROW_ASSIGN_OR_RAISE(ucp_tag_t data_tag, get_want_data_tag(data_uri));
+ const std::string& ident = info->endpoints()[0].ticket.ticket;
+
+ ARROW_ASSIGN_OR_RAISE(auto cuda_mgr,
arrow::cuda::CudaDeviceManager::Instance());
+ ARROW_ASSIGN_OR_RAISE(auto device, cuda_mgr->GetDevice(0));
+ ARROW_ASSIGN_OR_RAISE(auto cuda_device, arrow::cuda::AsCudaDevice(device));
+ ARROW_ASSIGN_OR_RAISE(auto ctx, cuda_device->GetContext());
+ cuCtxPushCurrent(reinterpret_cast<CUcontext>(ctx->handle()));
+
+ ARROW_LOG(DEBUG) << device->ToString();
+
+ UcxClient ctrl_client, data_client;
+ ARROW_RETURN_NOT_OK(ctrl_client.Init(ctrl_uri.host(), ctrl_uri.port()));
+ ARROW_RETURN_NOT_OK(data_client.Init(data_uri.host(), data_uri.port()));
+
+ ARROW_ASSIGN_OR_RAISE(auto ctrl_cnxn, ctrl_client.CreateConn());
+ ARROW_ASSIGN_OR_RAISE(auto data_cnxn, data_client.CreateConn());
+
+ StreamReader rdr(ctrl_cnxn.get(), data_cnxn.get());
+ rdr.set_data_mem_manager(ctx->memory_manager());
+
+ ARROW_RETURN_NOT_OK(rdr.Start(ctrl_tag, data_tag, ident));
+
+ ARROW_ASSIGN_OR_RAISE(auto s, rdr.Schema());
+ std::cout << s->ToString() << std::endl;
+ while (true) {
+ ARROW_ASSIGN_OR_RAISE(auto batch, rdr.Next());
+ if (!batch) {
+ break;
+ }
+
+ std::cout << batch->num_columns() << " " << batch->num_rows() << std::endl;
+ std::cout << batch->column(0)->data()->buffers[1]->device()->ToString() <<
std::endl;
+ ARROW_ASSIGN_OR_RAISE(auto cpubatch,
+ batch->CopyTo(arrow::default_cpu_memory_manager()));
+ std::cout << cpubatch->ToString() << std::endl;
+ }
+
+ ARROW_CHECK_OK(ctrl_cnxn->Close());
+ ARROW_CHECK_OK(data_cnxn->Close());
+ return arrow::Status::OK();
+}
diff --git a/dissociated-ipc/cudf-flight-server.cc
b/dissociated-ipc/cudf-flight-server.cc
new file mode 100644
index 0000000..19b1951
--- /dev/null
+++ b/dissociated-ipc/cudf-flight-server.cc
@@ -0,0 +1,408 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+#include <cudf/interop.hpp>
+#include <cudf/io/parquet.hpp>
+
+#include <arrow/c/abi.h>
+#include <arrow/c/bridge.h>
+#include <arrow/flight/server.h>
+#include <arrow/gpu/cuda_api.h>
+#include <arrow/ipc/api.h>
+#include <arrow/util/endian.h>
+#include <arrow/util/logging.h>
+#include <arrow/util/uri.h>
+
+#include "cudf-flight-ucx.h"
+#include "ucx_server.h"
+
+namespace flight = arrow::flight;
+namespace ipc = arrow::ipc;
+
+cudf::column_metadata column_info_to_metadata(const
cudf::io::column_name_info& info) {
+ cudf::column_metadata result;
+ result.name = info.name;
+ std::transform(info.children.begin(), info.children.end(),
+ std::back_inserter(result.children_meta),
column_info_to_metadata);
+ return result;
+}
+
+std::vector<cudf::column_metadata> table_metadata_to_column(
+ const cudf::io::table_metadata& tbl_meta) {
+ std::vector<cudf::column_metadata> result;
+
+ std::transform(tbl_meta.schema_info.begin(), tbl_meta.schema_info.end(),
+ std::back_inserter(result), column_info_to_metadata);
+ return result;
+}
+
+// a UCX server which serves cuda record batches via the dissociated ipc
protocol
+class CudaUcxServer : public UcxServer {
+ public:
+ CudaUcxServer() {
+ // create a buffer holding 8 bytes on the GPU to use for padding buffers
+ cuda_padding_bytes_ = rmm::device_buffer(8, rmm::cuda_stream_view{});
+ cuMemsetD8(reinterpret_cast<uintptr_t>(cuda_padding_bytes_.data()), 0, 8);
+ }
+ virtual ~CudaUcxServer() {
+ if (listening_.load()) {
+ ARROW_UNUSED(Shutdown());
+ }
+ }
+
+ arrow::Status initialize() {
+ // load the parquet data directly onto the GPU as a libcudf table
+ auto source = cudf::io::source_info("./data/taxi-data/train.parquet");
+ auto options = cudf::io::parquet_reader_options::builder(source);
+ cudf::io::chunked_parquet_reader rdr(1 * 1024 * 1024, options);
+
+ // get arrow::RecordBatches for each chunk of the parquet data while
+ // leaving the data on the GPU
+ arrow::RecordBatchVector batches;
+ auto chunk = rdr.read_chunk();
+ auto schema = cudf::to_arrow_schema(chunk.tbl->view(),
+
table_metadata_to_column(chunk.metadata));
+ auto device_out = cudf::to_arrow_device(std::move(*chunk.tbl));
+ ARROW_ASSIGN_OR_RAISE(auto data,
+ arrow::ImportDeviceRecordBatch(device_out.get(),
schema.get()));
+
+ batches.push_back(std::move(data));
+
+ while (rdr.has_next()) {
+ chunk = rdr.read_chunk();
+ device_out = cudf::to_arrow_device(std::move(*chunk.tbl));
+ ARROW_ASSIGN_OR_RAISE(
+ data, arrow::ImportDeviceRecordBatch(device_out.get(),
schema.get()));
+ batches.push_back(std::move(data));
+ }
+
+ data_sets_.emplace("train.parquet", std::move(batches));
+
+ // initialize the server and let it choose its own port
+ ARROW_RETURN_NOT_OK(Init("127.0.0.1", 0));
+
+ ARROW_ASSIGN_OR_RAISE(ctrl_location_,
+ flight::Location::Parse(location_.ToString() +
"?want_data=" +
+
std::to_string(kWantCtrlTag)));
+ ARROW_ASSIGN_OR_RAISE(data_location_,
+ flight::Location::Parse(location_.ToString() +
"?want_data=" +
+
std::to_string(kWantDataTag)));
+ return arrow::Status::OK();
+ }
+
+ inline flight::Location ctrl_location() const { return ctrl_location_; }
+ inline flight::Location data_location() const { return data_location_; }
+
+ protected:
+ arrow::Status setup_handlers(UcxServer::ClientWorker* worker) override {
+ return arrow::Status::OK();
+ }
+
+ arrow::Status do_work(UcxServer::ClientWorker* worker) override {
+ // probe for a message with the want_data tag synchronously,
+ // so this will block until it receives a message with this tag
+ ARROW_ASSIGN_OR_RAISE(
+ auto tag_info, worker->conn_->ProbeForTagSync(kWantDataTag,
~kWantCtrlMask, 1));
+
+ std::string msg;
+ msg.resize(tag_info.first.length);
+ ARROW_RETURN_NOT_OK(
+ worker->conn_->RecvTagData(tag_info.second,
reinterpret_cast<void*>(msg.data()),
+ msg.size(), nullptr, nullptr,
UCS_MEMORY_TYPE_HOST));
+
+ ARROW_LOG(DEBUG) << "server received WantData: " << msg;
+
+ // simulate two separate servers, one metadata server and one body data
server
+ if (tag_info.first.sender_tag & kWantCtrlMask) {
+ return send_metadata_stream(worker, msg);
+ }
+
+ return send_data_stream(worker, msg);
+ }
+
+ private:
+ arrow::Status send_metadata_stream(UcxServer::ClientWorker* worker,
+ const std::string& ident) {
+ auto it = data_sets_.find(ident);
+ if (it == data_sets_.end()) {
+ return arrow::Status::Invalid("data set not found:", ident);
+ }
+
+ ipc::IpcWriteOptions ipc_options;
+ ipc::DictionaryFieldMapper mapper;
+ const auto& record_list = it->second;
+ auto schema = record_list[0]->schema();
+ ARROW_RETURN_NOT_OK(mapper.AddSchemaFields(*schema));
+
+ // for each record in the stream, collect the IPC metadata to send
+ uint32_t sequence_num = 0;
+ // schema payload is first
+ ipc::IpcPayload payload;
+ ARROW_RETURN_NOT_OK(ipc::GetSchemaPayload(*schema, ipc_options, mapper,
&payload));
+ ARROW_RETURN_NOT_OK(write_ipc_metadata(worker->conn_.get(), payload,
sequence_num++));
+
+ // then any dictionaries
+ ARROW_ASSIGN_OR_RAISE(const auto dictionaries,
+ ipc::CollectDictionaries(*record_list[0], mapper));
+ for (const auto& pair : dictionaries) {
+ ARROW_RETURN_NOT_OK(
+ ipc::GetDictionaryPayload(pair.first, pair.second, ipc_options,
&payload));
+ ARROW_RETURN_NOT_OK(
+ write_ipc_metadata(worker->conn_.get(), payload, sequence_num++));
+ }
+
+ // finally the record batch metadata messages
+ for (const auto& batch : record_list) {
+ ARROW_RETURN_NOT_OK(ipc::GetRecordBatchPayload(*batch, ipc_options,
&payload));
+ ARROW_RETURN_NOT_OK(
+ write_ipc_metadata(worker->conn_.get(), payload, sequence_num++));
+ }
+
+ // finally, we send the End-Of-Stream message
+ std::array<uint8_t, 5>
eos_bytes{static_cast<uint8_t>(MetadataMsgType::EOS), 0, 0, 0,
+ 0};
+ utils::Uint32ToBytesLE(sequence_num, eos_bytes.data() + 1);
+
+ ARROW_RETURN_NOT_OK(worker->conn_->Flush());
+ return worker->conn_->SendAM(0, eos_bytes.data(), eos_bytes.size());
+ }
+
+ struct PendingIOV {
+ std::vector<ucp_dt_iov_t> iovs;
+ arrow::BufferVector body_buffers;
+ };
+
+ arrow::Status write_ipc_metadata(utils::Connection* cnxn,
+ const ipc::IpcPayload& payload,
+ const uint32_t sequence_num) {
+ // our metadata messages are always CPU host memory
+ ucs_memory_type_t mem_type = UCS_MEMORY_TYPE_HOST;
+
+ // construct our 5 byte prefix, the message type followed by the sequence
number
+ auto pending = std::make_unique<PendingIOV>();
+ pending->iovs.resize(2);
+ pending->iovs[0].buffer = malloc(5);
+ pending->iovs[0].length = 5;
+ reinterpret_cast<uint8_t*>(pending->iovs[0].buffer)[0] =
+ static_cast<uint8_t>(MetadataMsgType::METADATA);
+ utils::Uint32ToBytesLE(sequence_num,
+ reinterpret_cast<uint8_t*>(pending->iovs[0].buffer)
+ 1);
+
+ // after the prefix, we add the metadata we want to send
+ pending->iovs[1].buffer =
const_cast<void*>(payload.metadata->data_as<void>());
+ pending->iovs[1].length = payload.metadata->size();
+ pending->body_buffers.emplace_back(payload.metadata);
+
+ auto* pending_iov = pending.get();
+ void* user_data = pending.release();
+ return cnxn->SendAMIov(
+ 0, pending_iov->iovs.data(), pending_iov->iovs.size(), user_data,
+ [](void* request, ucs_status_t status, void* user_data) {
+ auto pending_iov =
+
std::unique_ptr<PendingIOV>(reinterpret_cast<PendingIOV*>(user_data));
+ if (request) ucp_request_free(request);
+ if (status != UCS_OK) {
+ ARROW_LOG(ERROR)
+ << utils::FromUcsStatus("ucp_am_send_nbx_cb",
status).ToString();
+ }
+ free(pending_iov->iovs[0].buffer);
+ },
+ mem_type);
+ }
+
+ arrow::Status send_data_stream(UcxServer::ClientWorker* worker,
+ const std::string& ident) {
+ auto it = data_sets_.find(ident);
+ if (it == data_sets_.end()) {
+ return arrow::Status::Invalid("data set not found:", ident);
+ }
+
+ ipc::IpcWriteOptions ipc_options;
+ ipc::DictionaryFieldMapper mapper;
+ const auto& record_list = it->second;
+ auto schema = record_list[0]->schema();
+ ARROW_RETURN_NOT_OK(mapper.AddSchemaFields(*schema));
+
+ // start at 1 since schema payload has no body
+ uint32_t sequence_num = 1;
+
+ ipc::IpcPayload payload;
+ ARROW_ASSIGN_OR_RAISE(const auto dictionaries,
+ ipc::CollectDictionaries(*record_list[0], mapper));
+ for (const auto& pair : dictionaries) {
+ ARROW_RETURN_NOT_OK(
+ ipc::GetDictionaryPayload(pair.first, pair.second, ipc_options,
&payload));
+ ARROW_RETURN_NOT_OK(write_ipc_body(worker->conn_.get(), payload,
sequence_num++));
+ }
+
+ for (const auto& batch : record_list) {
+ ARROW_RETURN_NOT_OK(ipc::GetRecordBatchPayload(*batch, ipc_options,
&payload));
+ ARROW_RETURN_NOT_OK(write_ipc_body(worker->conn_.get(), payload,
sequence_num++));
+ }
+
+ return worker->conn_->Flush();
+ }
+
+ arrow::Status write_ipc_body(utils::Connection* cnxn, const ipc::IpcPayload&
payload,
+ const uint32_t sequence_num) {
+ ucs_memory_type_t mem_type = UCS_MEMORY_TYPE_CUDA;
+
+ // determine the number of buffers and padding we need along with the
total size
+ auto pending = std::make_unique<PendingIOV>();
+ int32_t total_buffers = 0;
+ for (const auto& buffer : payload.body_buffers) {
+ if (!buffer || buffer->size() == 0) continue;
+ if (buffer->is_cpu()) {
+ mem_type = UCS_MEMORY_TYPE_HOST;
+ }
+ total_buffers++;
+ // arrow ipc requires aligning buffers to 8 byte boundary
+ const auto remainder = static_cast<int>(
+ arrow::bit_util::RoundUpToMultipleOf8(buffer->size()) -
buffer->size());
+ if (remainder) total_buffers++;
+ }
+
+ pending->iovs.resize(total_buffers);
+ // we'll use scatter-gather to avoid extra copies
+ ucp_dt_iov_t* iov = pending->iovs.data();
+ pending->body_buffers = payload.body_buffers;
+
+ void* padding_bytes =
+ const_cast<void*>(reinterpret_cast<const
void*>(padding_bytes_.data()));
+ if (mem_type == UCS_MEMORY_TYPE_CUDA) {
+ padding_bytes = cuda_padding_bytes_.data();
+ }
+
+ for (const auto& buffer : payload.body_buffers) {
+ if (!buffer || buffer->size() == 0) continue;
+ // for cuda memory, buffer->address() will return a device pointer
+ iov->buffer = const_cast<void*>(reinterpret_cast<const
void*>(buffer->address()));
+ iov->length = buffer->size();
+ ++iov;
+
+ const auto remainder = static_cast<int>(
+ arrow::bit_util::RoundUpToMultipleOf8(buffer->size()) -
buffer->size());
+ if (remainder) {
+ iov->buffer = padding_bytes;
+ iov->length = remainder;
+ ++iov;
+ }
+ }
+
+ auto pending_iov = pending.release();
+ // indicate that we're sending the full data body, not a pointer list and
add
+ // the sequence number to the tag
+ ucp_tag_t tag =
+ (uint64_t(0) << kShiftBodyType) |
arrow::bit_util::ToLittleEndian(sequence_num);
+ return cnxn->SendTagIov(
+ tag, pending_iov->iovs.data(), pending_iov->iovs.size(), pending_iov,
+ [](void* request, ucs_status_t status, void* user_data) {
+ auto pending_iov =
+
std::unique_ptr<PendingIOV>(reinterpret_cast<PendingIOV*>(user_data));
+ if (status != UCS_OK) {
+ ARROW_LOG(ERROR)
+ << utils::FromUcsStatus("ucp_tag_send_nbx_cb",
status).ToString();
+ }
+ if (request) {
+ ucp_request_free(request);
+ }
+ },
+ mem_type);
+ }
+
  // locations of the control and data UCX listeners
  // NOTE(review): presumably populated during initialize() and returned by
  // ctrl_location()/data_location(); that code is outside this view.
  flight::Location ctrl_location_;
  flight::Location data_location_;
  // data set identifier -> record batches; looked up by send_data_stream
  std::unordered_map<std::string, arrow::RecordBatchVector> data_sets_;

  // device-memory source of padding bytes, used by write_ipc_body when the
  // body buffers are CUDA memory
  // NOTE(review): its size and zero-fill are set up elsewhere; confirm it holds
  // at least 7 zero bytes.
  rmm::device_buffer cuda_padding_bytes_;
  // host-memory zero bytes used by write_ipc_body to pad CPU body buffers to 8 bytes
  const std::array<uint8_t, 8> padding_bytes_{0, 0, 0, 0, 0, 0, 0, 0};
+};
+
+// a flight server that will serve up the flight-info with a ucx uri to point
+// to the cuda ucx server
+class CudaFlightServer : public flight::FlightServerBase {
+ public:
+ CudaFlightServer(flight::Location ctrl_server_loc, flight::Location
data_server_loc)
+ : FlightServerBase(),
+ ctrl_server_loc_{std::move(ctrl_server_loc)},
+ data_server_loc_{std::move(data_server_loc)} {}
+ ~CudaFlightServer() override {}
+
+ inline void register_schema(std::string cmd, std::shared_ptr<arrow::Schema>
schema) {
+ schema_reg_.emplace(std::move(cmd), std::move(schema));
+ }
+
+ arrow::Status GetFlightInfo(const flight::ServerCallContext& context,
+ const flight::FlightDescriptor& request,
+ std::unique_ptr<flight::FlightInfo>* info)
override {
+ flight::FlightEndpoint endpoint{
+ {request.cmd}, {ctrl_server_loc_, data_server_loc_}, std::nullopt, {}};
+
+ auto it = schema_reg_.find(request.cmd);
+ if (it == schema_reg_.end() || !it->second) {
+ return arrow::Status::Invalid("could not find schema for ", request.cmd);
+ }
+
+ ARROW_ASSIGN_OR_RAISE(
+ auto flightinfo,
+ flight::FlightInfo::Make(*it->second, request, {endpoint}, -1, -1,
false));
+ *info = std::make_unique<flight::FlightInfo>(std::move(flightinfo));
+ return arrow::Status::OK();
+ }
+
+ private:
+ std::unordered_map<std::string, std::shared_ptr<arrow::Schema>> schema_reg_;
+ flight::Location ctrl_server_loc_;
+ flight::Location data_server_loc_;
+};
+
+arrow::Status run_server(const std::string& addr, const int port) {
+ CudaUcxServer server;
+ ARROW_ASSIGN_OR_RAISE(auto mgr, arrow::cuda::CudaDeviceManager::Instance());
+ ARROW_ASSIGN_OR_RAISE(auto ctx, mgr->GetContext(0));
+ server.set_cuda_context(ctx);
+ ARROW_RETURN_NOT_OK(server.initialize());
+
+ flight::Location ctrl_server_location = server.ctrl_location();
+ flight::Location data_server_location = server.data_location();
+
+ std::shared_ptr<arrow::Schema> sc;
+ {
+ auto source = cudf::io::source_info("./data/taxi-data/train.parquet");
+ auto options =
cudf::io::parquet_reader_options::builder(source).num_rows(1);
+ auto result = cudf::io::read_parquet(options);
+
+ auto schema = cudf::to_arrow_schema(result.tbl->view(),
+
table_metadata_to_column(result.metadata));
+ ARROW_ASSIGN_OR_RAISE(sc, arrow::ImportSchema(schema.get()));
+ }
+
+ auto flight_server = std::make_shared<CudaFlightServer>(
+ std::move(ctrl_server_location), std::move(data_server_location));
+ flight_server->register_schema("train.parquet", std::move(sc));
+
+ ARROW_ASSIGN_OR_RAISE(auto loc, flight::Location::ForGrpcTcp(addr, port));
+ flight::FlightServerOptions options(loc);
+
+ RETURN_NOT_OK(flight_server->Init(options));
+ RETURN_NOT_OK(flight_server->SetShutdownOnSignals({SIGTERM}));
+
+ std::cout << "Flight Server Listening on " <<
flight_server->location().ToString()
+ << std::endl;
+
+ return flight_server->Serve();
+}
\ No newline at end of file
diff --git a/dissociated-ipc/cudf-flight-ucx.cc
b/dissociated-ipc/cudf-flight-ucx.cc
new file mode 100644
index 0000000..478b301
--- /dev/null
+++ b/dissociated-ipc/cudf-flight-ucx.cc
@@ -0,0 +1,39 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+#include <arrow/status.h>
+#include <arrow/util/logging.h>
+#include <arrow/util/uri.h>
+#include <gflags/gflags.h>
+
+#include "cudf-flight-ucx.h"
+
// --port: port the Flight server listens on, or the client connects to
DEFINE_int32(port, 31337, "port to listen or connect");
// --address: host address passed as `addr` to run_client / run_server
DEFINE_string(address, "127.0.0.1", "address to connect to");
// --client: run the client instead of the server (the server is the default)
DEFINE_bool(client, false, "run the client");
+
+int main(int argc, char** argv) {
+ arrow::util::ArrowLog::StartArrowLog("cudf-flight-poc",
+
arrow::util::ArrowLogLevel::ARROW_DEBUG);
+
+ gflags::ParseCommandLineFlags(&argc, &argv, true);
+ if (FLAGS_client) {
+ ARROW_CHECK_OK(run_client(FLAGS_address, FLAGS_port));
+ } else {
+ ARROW_CHECK_OK(run_server(FLAGS_address, FLAGS_port));
+ }
+}
diff --git a/dissociated-ipc/cudf-flight-ucx.h
b/dissociated-ipc/cudf-flight-ucx.h
new file mode 100644
index 0000000..c260bd2
--- /dev/null
+++ b/dissociated-ipc/cudf-flight-ucx.h
@@ -0,0 +1,38 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+#pragma once
+
+#include <ucp/api/ucp.h>
+
+// Define some constants for the `want_data` tags
+static constexpr ucp_tag_t kWantDataTag = 0x00000DEADBA0BAB0;
+static constexpr ucp_tag_t kWantCtrlTag = 0xFFFFFDEADBA0BAB0;
+// define a mask to check the tag
+static constexpr ucp_tag_t kWantCtrlMask = 0xFFFFF00000000000;
+
+// constant for the bit shift to make the data body type the most
+// significant byte
+static constexpr int kShiftBodyType = 55;
+
+enum class MetadataMsgType : uint8_t {
+ EOS = 0,
+ METADATA = 1,
+};
+
+arrow::Status run_server(const std::string& addr, const int port);
+arrow::Status run_client(const std::string& addr, const int port);
\ No newline at end of file
diff --git a/dissociated-ipc/ucx_client.cc b/dissociated-ipc/ucx_client.cc
new file mode 100644
index 0000000..91bf140
--- /dev/null
+++ b/dissociated-ipc/ucx_client.cc
@@ -0,0 +1,73 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+#include "ucx_client.h"
+#include "ucx_utils.h"
+
+#include <memory>
+#include <string>
+
+arrow::Status UcxClient::Init(const std::string& host, const int32_t port) {
+ ucp_config_t* ucp_config;
+ ucp_params_t ucp_params;
+ ucs_status_t status;
+
+ status = ucp_config_read(nullptr, nullptr, &ucp_config);
+ ARROW_RETURN_NOT_OK(utils::FromUcsStatus("ucp_config_read", status));
+
+ // if location is IPv6 must adjust UCX config
+ // we assume locations always resolve to IPv6 or IPv4
+ // but that's not necessarily true
+ ARROW_ASSIGN_OR_RAISE(addrlen_, utils::to_sockaddr(host, port,
&connect_addr_));
+ if (connect_addr_.ss_family == AF_INET6) {
+ ARROW_RETURN_NOT_OK(utils::FromUcsStatus(
+ "ucp_config_modify", ucp_config_modify(ucp_config, "AF_PRIO",
"inet6")));
+ }
+
+ std::memset(&ucp_params, 0, sizeof(ucp_params));
+ ucp_params.field_mask = UCP_PARAM_FIELD_FEATURES;
+ ucp_params.features = UCP_FEATURE_WAKEUP | UCP_FEATURE_AM | UCP_FEATURE_RMA |
+ UCP_FEATURE_STREAM | UCP_FEATURE_TAG;
+
+ ucp_context_h ucp_context;
+ status = ucp_init(&ucp_params, ucp_config, &ucp_context);
+ ucp_config_release(ucp_config);
+
+ ARROW_RETURN_NOT_OK(utils::FromUcsStatus("ucp_init", status));
+ ucp_context_.reset(new utils::UcpContext(ucp_context));
+ return arrow::Status::OK();
+}
+
+arrow::Result<std::unique_ptr<utils::Connection>> UcxClient::CreateConn() {
+ ucp_worker_params_t worker_params;
+ std::memset(&worker_params, 0, sizeof(worker_params));
+ worker_params.field_mask =
+ UCP_WORKER_PARAM_FIELD_THREAD_MODE | UCP_WORKER_PARAM_FIELD_FLAGS;
+ worker_params.thread_mode = UCS_THREAD_MODE_SERIALIZED;
+ worker_params.flags = UCP_WORKER_FLAG_IGNORE_REQUEST_LEAK;
+
+ ucp_worker_h ucp_worker;
+ ucs_status_t status =
+ ucp_worker_create(ucp_context_->get(), &worker_params, &ucp_worker);
+ ARROW_RETURN_NOT_OK(utils::FromUcsStatus("ucp_worker_create", status));
+
+ auto cnxn = std::make_unique<utils::Connection>(
+ std::make_shared<utils::UcpWorker>(ucp_context_, ucp_worker));
+ ARROW_RETURN_NOT_OK(cnxn->CreateEndpoint(connect_addr_, addrlen_));
+
+ return cnxn;
+}
diff --git a/dissociated-ipc/ucx_client.h b/dissociated-ipc/ucx_client.h
new file mode 100644
index 0000000..ae8a206
--- /dev/null
+++ b/dissociated-ipc/ucx_client.h
@@ -0,0 +1,40 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+#pragma once
+
+#include "ucx_conn.h"
+#include "ucx_utils.h"
+
+#include <memory>
+#include <string>
+
+#include <arrow/status.h>
+
+class UcxClient {
+ public:
+ UcxClient() = default;
+ ~UcxClient() = default;
+
+ arrow::Status Init(const std::string& host, const int32_t port);
+ arrow::Result<std::unique_ptr<utils::Connection>> CreateConn();
+
+ private:
+ std::shared_ptr<utils::UcpContext> ucp_context_;
+ struct sockaddr_storage connect_addr_;
+ size_t addrlen_;
+};
diff --git a/dissociated-ipc/ucx_conn.cc b/dissociated-ipc/ucx_conn.cc
new file mode 100644
index 0000000..e07d66b
--- /dev/null
+++ b/dissociated-ipc/ucx_conn.cc
@@ -0,0 +1,355 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+#include "ucx_conn.h"
+
+#include <limits>
+#include <string>
+
+#include <arrow/device.h>
+
+namespace utils {
+
+ucs_status_t wait_for_request(ucs_status_ptr_t request, UcpWorker& worker) {
+ ucs_status_t status = UCS_OK;
+ if (UCS_PTR_IS_ERR(request)) {
+ status = UCS_PTR_STATUS(request);
+ } else if (UCS_PTR_IS_PTR(request)) {
+ while ((status = ucp_request_check_status(request)) == UCS_INPROGRESS) {
+ ucp_worker_progress(worker.get());
+ }
+ ucp_request_free(request);
+ } else {
+ DCHECK(!request);
+ }
+ return status;
+}
+
+Connection::Connection(std::shared_ptr<UcpWorker> worker)
+ : ucp_worker_{std::move(worker)} {}
+
+Connection::Connection(std::shared_ptr<UcpWorker> worker, ucp_ep_h endpoint)
+ : ucp_worker_{std::move(worker)}, remote_endpoint_(endpoint) {}
+
+arrow::Status Connection::CreateEndpoint(ucp_conn_request_h request) {
+ ucs_status_t status;
+ ucp_ep_params_t params;
+ std::memset(¶ms, 0, sizeof(params));
+ params.field_mask = UCP_EP_PARAM_FIELD_CONN_REQUEST |
UCP_EP_PARAM_FIELD_ERR_HANDLER;
+ params.err_handler.arg = this;
+ params.err_handler.cb = Connection::err_cb;
+ params.conn_request = request;
+
+ return FromUcsStatus("ucp_ep_create",
+ ucp_ep_create(ucp_worker_->get(), ¶ms,
&remote_endpoint_));
+}
+
+arrow::Status Connection::CreateEndpoint(const sockaddr_storage& connect_addr,
+ const size_t addrlen) {
+ std::string peer;
+ ARROW_UNUSED(SockaddrToString(connect_addr).Value(&peer));
+ ARROW_LOG(DEBUG) << "Connecting to " << peer;
+
+ ucp_ep_params_t params;
+ params.field_mask =
+ UCP_EP_PARAM_FIELD_FLAGS | UCP_EP_PARAM_FIELD_NAME |
UCP_EP_PARAM_FIELD_SOCK_ADDR;
+ params.flags = UCP_EP_PARAMS_FLAGS_CLIENT_SERVER |
UCP_EP_PARAMS_FLAGS_SEND_CLIENT_ID;
+ params.name = "UcxConn";
+ params.sockaddr.addr = reinterpret_cast<const sockaddr*>(&connect_addr);
+ params.sockaddr.addrlen = addrlen;
+
+ return FromUcsStatus("ucp_ep_create",
+ ucp_ep_create(ucp_worker_->get(), ¶ms,
&remote_endpoint_));
+}
+
+arrow::Status Connection::Flush() {
+ ARROW_RETURN_NOT_OK(CheckClosed());
+
+ ucp_request_param_t param;
+ param.op_attr_mask = 0;
+ void* request = ucp_ep_flush_nbx(remote_endpoint_, ¶m);
+ if (!request) {
+ return arrow::Status::OK();
+ }
+
+ return utils::FromUcsStatus("ucp_ep_flush_nbx",
+ wait_for_request(request, *ucp_worker_));
+}
+
+arrow::Status Connection::Close() {
+ ucp_request_param_t params;
+ std::memset(¶ms, 0, sizeof(ucp_request_param_t));
+ params.flags = UCP_EP_CLOSE_FLAG_FORCE;
+
+ void* request = ucp_ep_close_nbx(remote_endpoint_, ¶ms);
+ auto status = wait_for_request(request, *ucp_worker_);
+
+ remote_endpoint_ = nullptr;
+ ucp_worker_.reset();
+ if (status != UCS_OK && !is_ignorable_disconnect_error(status)) {
+ return FromUcsStatus("close conn", status);
+ }
+ return arrow::Status::OK();
+}
+
+arrow::Status Connection::SetAMHandler(unsigned int id, void* user_data,
+ ucp_am_recv_callback_t cb) {
+ ucp_am_handler_param_t params;
+ params.field_mask = UCP_AM_HANDLER_PARAM_FIELD_ID;
+ params.id = id;
+ if (user_data) {
+ params.field_mask |= UCP_AM_HANDLER_PARAM_FIELD_ARG;
+ params.arg = user_data;
+ }
+ if (cb) {
+ params.field_mask |= UCP_AM_HANDLER_PARAM_FIELD_CB;
+ params.cb = cb;
+ }
+ return utils::FromUcsStatus(
+ "ucp_worker_set_am_recv_handler",
+ ucp_worker_set_am_recv_handler(ucp_worker_->get(), ¶ms));
+}
+
+arrow::Result<std::pair<ucp_tag_recv_info_t, ucp_tag_message_h>>
Connection::ProbeForTag(
+ ucp_tag_t tag, ucp_tag_t mask, int remove) {
+ ARROW_RETURN_NOT_OK(CheckClosed());
+
+ ucp_tag_recv_info_t info_tag;
+ auto msg_tag = ucp_tag_probe_nb(ucp_worker_->get(), tag, mask, remove,
&info_tag);
+ return std::make_pair(info_tag, msg_tag);
+}
+
+arrow::Result<std::pair<ucp_tag_recv_info_t, ucp_tag_message_h>>
+Connection::ProbeForTagSync(ucp_tag_t tag, ucp_tag_t mask, int remove) {
+ ARROW_RETURN_NOT_OK(CheckClosed());
+
+ ucp_tag_recv_info_t info_tag;
+ ucp_tag_message_h msg_tag;
+ while (true) {
+ msg_tag = ucp_tag_probe_nb(ucp_worker_->get(), tag, mask, remove,
&info_tag);
+ if (msg_tag != nullptr) {
+ // success
+ break;
+ } else if (ucp_worker_progress(ucp_worker_->get())) {
+ // some events polled, try again
+ continue;
+ }
+
+ // ucp_worker_progress 0, so we sleep
+ // following blocked method used to poll internal file descriptor
+ // to make CPU idle and not spin loop
+ ARROW_RETURN_NOT_OK(
+ FromUcsStatus("ucp_worker_wait", ucp_worker_wait(ucp_worker_->get())));
+ }
+
+ return std::make_pair(info_tag, msg_tag);
+}
+
// State handed to the rendezvous receive callback in Connection::RecvAM: the
// promise to fulfil and the destination buffer the data is received into.
// RecvAM allocates it with `new`; the completion callback reclaims it via
// unique_ptr.
struct RndvPromiseBuffer {
  std::promise<std::unique_ptr<arrow::Buffer>> p;
  std::unique_ptr<arrow::Buffer> buf;
};
+
+ucs_status_t Connection::RecvAM(std::promise<std::unique_ptr<arrow::Buffer>> p,
+ const void* header, const size_t header_length,
+ void* data, const size_t data_length,
+ const ucp_am_recv_param_t* param) {
+ if (data_length > static_cast<size_t>(std::numeric_limits<int32_t>::max())) {
+ ARROW_LOG(ERROR) << "cannot allocate buffer greater than 2 GiB, requested:
"
+ << data_length;
+ return UCS_ERR_IO_ERROR;
+ }
+
+ if (param->recv_attr & UCP_AM_RECV_ATTR_FLAG_DATA) {
+ // data provided can be held by us. return UCS_INPROGRESS to make the data
persist
+ // and we will eventually use ucp_am_data_release to release it.
+ auto buffer = std::make_unique<UcxDataBuffer>(ucp_worker_, data,
data_length);
+ p.set_value(std::move(buffer));
+ return UCS_INPROGRESS;
+ }
+
+ // rendezvous protocol
+ if (param->recv_attr & UCP_AM_RECV_ATTR_FLAG_RNDV) {
+ auto maybe_buffer =
arrow::default_cpu_memory_manager()->AllocateBuffer(data_length);
+ if (!maybe_buffer.ok()) {
+ ARROW_LOG(ERROR) << "could not allocate buffer for message: "
+ << maybe_buffer.status().ToString();
+ return UCS_ERR_NO_MEMORY;
+ }
+
+ auto buffer = maybe_buffer.MoveValueUnsafe();
+ void* dest = reinterpret_cast<void*>(buffer->mutable_address());
+
+ ucp_request_param_t recv_param;
+ recv_param.op_attr_mask = UCP_OP_ATTR_FIELD_CALLBACK |
UCP_OP_ATTR_FIELD_MEMORY_TYPE |
+ UCP_OP_ATTR_FIELD_USER_DATA |
UCP_OP_ATTR_FLAG_NO_IMM_CMPL;
+ recv_param.memory_type = UCS_MEMORY_TYPE_HOST;
+ recv_param.user_data = new RndvPromiseBuffer{std::move(p),
std::move(buffer)};
+ recv_param.cb.recv_am = [](void* request, ucs_status_t status, size_t
length,
+ void* user_data) {
+ auto p = std::unique_ptr<RndvPromiseBuffer>(
+ reinterpret_cast<RndvPromiseBuffer*>(user_data));
+ if (request) {
+ ucp_request_free(request);
+ }
+ if (status == UCS_OK) {
+ p->p.set_value(std::move(p->buf));
+ } else {
+ ARROW_LOG(ERROR) << FromUcsStatus("ucp_am_recv_data_nbx cb",
status).ToString();
+ p->p.set_value(nullptr);
+ }
+ };
+ void* request =
+ ucp_am_recv_data_nbx(ucp_worker_->get(), data, dest, data_length,
&recv_param);
+ if (UCS_PTR_IS_ERR(request)) {
+ return UCS_PTR_STATUS(request);
+ }
+ return UCS_OK;
+ }
+
+ auto maybe_buffer =
arrow::default_cpu_memory_manager()->AllocateBuffer(data_length);
+ if (!maybe_buffer.ok()) {
+ ARROW_LOG(ERROR) << "could not allocate buffer for message: "
+ << maybe_buffer.status().ToString();
+ return UCS_ERR_NO_MEMORY;
+ }
+ auto buffer = maybe_buffer.MoveValueUnsafe();
+ std::memcpy(buffer->mutable_data(), data, data_length);
+ p.set_value(std::move(buffer));
+ return UCS_OK;
+}
+
+arrow::Status Connection::RecvTagData(ucp_tag_message_h msg, void* buffer,
+ const size_t count, void* user_data,
+ ucp_tag_recv_nbx_callback_t cb,
+ const ucs_memory_type_t memory_type) {
+ ARROW_RETURN_NOT_OK(CheckClosed());
+
+ ucp_request_param_t recv_param;
+ recv_param.op_attr_mask = UCP_OP_ATTR_FLAG_NO_IMM_CMPL |
UCP_OP_ATTR_FIELD_DATATYPE |
+ UCP_OP_ATTR_FIELD_MEMORY_TYPE;
+ recv_param.datatype = ucp_dt_make_contig(1);
+ recv_param.memory_type = memory_type;
+ if (user_data) {
+ recv_param.user_data = user_data;
+ recv_param.op_attr_mask |= UCP_OP_ATTR_FIELD_USER_DATA;
+ }
+ if (cb) {
+ recv_param.cb.recv = cb;
+ recv_param.op_attr_mask |= UCP_OP_ATTR_FIELD_CALLBACK;
+ }
+
+ auto request =
+ ucp_tag_msg_recv_nbx(ucp_worker_->get(), buffer, count, msg,
&recv_param);
+ return FromUcsStatus("recvtagdata", wait_for_request(request, *ucp_worker_));
+}
+
+arrow::Status Connection::SendAM(unsigned int id, const void* data, const
int64_t size) {
+ ARROW_RETURN_NOT_OK(CheckClosed());
+
+ ucp_request_param_t request_param;
+ request_param.op_attr_mask = UCP_OP_ATTR_FIELD_FLAGS;
+ request_param.flags = UCP_AM_SEND_FLAG_REPLY;
+
+ auto request =
+ ucp_am_send_nbx(remote_endpoint_, id, nullptr, 0, data, size,
&request_param);
+ return FromUcsStatus("ucp_am_send_nbx", wait_for_request(request,
*ucp_worker_));
+}
+
+arrow::Status Connection::SendAMIov(unsigned int id, const ucp_dt_iov_t* iov,
+ const size_t iov_cnt, void* user_data,
+ ucp_send_nbx_callback_t cb,
+ const ucs_memory_type_t memory_type) {
+ ARROW_RETURN_NOT_OK(CheckClosed());
+
+ ucp_request_param_t request_param;
+ request_param.op_attr_mask = UCP_OP_ATTR_FIELD_FLAGS |
UCP_OP_ATTR_FIELD_DATATYPE |
+ UCP_OP_ATTR_FIELD_MEMORY_TYPE;
+ request_param.flags = UCP_AM_SEND_FLAG_REPLY;
+ request_param.datatype = UCP_DATATYPE_IOV;
+ if (cb) {
+ request_param.cb.send = cb;
+ request_param.op_attr_mask |= UCP_OP_ATTR_FIELD_CALLBACK;
+ }
+ if (user_data) {
+ request_param.user_data = user_data;
+ request_param.op_attr_mask |= UCP_OP_ATTR_FIELD_USER_DATA;
+ }
+ request_param.memory_type = memory_type;
+
+ void* request =
+ ucp_am_send_nbx(remote_endpoint_, id, nullptr, 0, iov, iov_cnt,
&request_param);
+ if (!request) {
+ // request completed immediately, call the cb manually if it exists
+ // since it won't be called automatically
+ if (cb) cb(request, UCS_OK, user_data);
+ } else if (UCS_PTR_IS_ERR(request)) {
+ // same thing, call it manually
+ auto status = UCS_PTR_STATUS(request);
+ if (cb) cb(request, status, user_data);
+ return utils::FromUcsStatus("ucp_am_send_nbx", status);
+ }
+
+ // otherwise the callback will be called eventually when it completes
+ // we can just return success.
+ return arrow::Status::OK();
+}
+
+arrow::Status Connection::SendTagIov(ucp_tag_t tag, const ucp_dt_iov_t* iov,
+ const size_t iov_cnt, void* user_data,
+ ucp_send_nbx_callback_t cb,
+ const ucs_memory_type_t memory_type) {
+ ARROW_RETURN_NOT_OK(CheckClosed());
+
+ ucp_request_param_t request_param;
+ request_param.op_attr_mask = UCP_OP_ATTR_FIELD_DATATYPE |
UCP_OP_ATTR_FIELD_MEMORY_TYPE;
+ request_param.datatype = UCP_DATATYPE_IOV;
+ if (cb) {
+ request_param.cb.send = cb;
+ request_param.op_attr_mask |= UCP_OP_ATTR_FIELD_CALLBACK;
+ }
+ if (user_data) {
+ request_param.user_data = user_data;
+ request_param.op_attr_mask |= UCP_OP_ATTR_FIELD_USER_DATA;
+ }
+ request_param.memory_type = memory_type;
+
+ void* request = ucp_tag_send_nbx(remote_endpoint_, iov, iov_cnt, tag,
&request_param);
+ if (!request) {
+ // request completed immediately, call the cb manually if it exists
+ // since it won't be called automatically
+ if (cb) cb(request, UCS_OK, user_data);
+ } else if (UCS_PTR_IS_ERR(request)) {
+ // same thing, call it manually
+ auto status = UCS_PTR_STATUS(request);
+ if (cb) cb(request, status, user_data);
+ return utils::FromUcsStatus("ucp_tag_send_nbx", status);
+ }
+
+ return arrow::Status::OK();
+}
+
+arrow::Status Connection::SendTagSync(ucp_tag_t tag, const void* buffer,
+ const size_t count) {
+ ARROW_RETURN_NOT_OK(CheckClosed());
+
+ ucp_request_param_t request_param;
+ ucs_status_ptr_t request =
+ ucp_tag_send_sync_nbx(remote_endpoint_, buffer, count, tag,
&request_param);
+ return FromUcsStatus("ucp_tag_send_sync_nbx", wait_for_request(request,
*ucp_worker_));
+}
+} // namespace utils
diff --git a/dissociated-ipc/ucx_conn.h b/dissociated-ipc/ucx_conn.h
new file mode 100644
index 0000000..ecaa6fc
--- /dev/null
+++ b/dissociated-ipc/ucx_conn.h
@@ -0,0 +1,90 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+#pragma once
+
+#include <ucp/api/ucp.h>
+#include <future>
+#include <memory>
+#include <utility>
+
+#include "arrow/util/logging.h"
+#include "ucx_utils.h"
+
+namespace utils {
+class Connection {
+ public:
+ explicit Connection(std::shared_ptr<UcpWorker> worker);
+ Connection(std::shared_ptr<UcpWorker> worker, ucp_ep_h endpoint);
+
+ ARROW_DISALLOW_COPY_AND_ASSIGN(Connection);
+ ARROW_DEFAULT_MOVE_AND_ASSIGN(Connection);
+ ~Connection() { DCHECK(!ucp_worker_) << "Connection was not closed!"; }
+
+ arrow::Status CreateEndpoint(ucp_conn_request_h request);
+ arrow::Status CreateEndpoint(const sockaddr_storage& addr, const size_t
addrlen);
+ arrow::Status Flush();
+ arrow::Status Close();
+ inline bool is_closed() const { return closed_; }
+ inline unsigned int Progress() { return
ucp_worker_progress(ucp_worker_->get()); }
+ inline ucs_status_t WorkerWait() { return
ucp_worker_wait(ucp_worker_->get()); }
+
+ arrow::Status SetAMHandler(unsigned int id, void* user_data,
ucp_am_recv_callback_t cb);
+
+ arrow::Result<std::pair<ucp_tag_recv_info_t, ucp_tag_message_h>> ProbeForTag(
+ ucp_tag_t tag, ucp_tag_t mask, int remove);
+ arrow::Result<std::pair<ucp_tag_recv_info_t, ucp_tag_message_h>>
ProbeForTagSync(
+ ucp_tag_t tag, ucp_tag_t mask, int remove);
+ arrow::Status RecvTagData(ucp_tag_message_h msg, void* buffer, const size_t
count,
+ void* user_data, ucp_tag_recv_nbx_callback_t cb,
+ const ucs_memory_type_t memory_type);
+ ucs_status_t RecvAM(std::promise<std::unique_ptr<arrow::Buffer>> p, const
void* header,
+ const size_t header_length, void* data, const size_t
data_length,
+ const ucp_am_recv_param_t* param);
+
+ arrow::Status SendAM(unsigned int id, const void* data, const int64_t size);
+ arrow::Status SendAMIov(unsigned int id, const ucp_dt_iov_t* iov, const
size_t iov_cnt,
+ void* user_data, ucp_send_nbx_callback_t cb,
+ const ucs_memory_type_t memory_type);
+ arrow::Status SendTagIov(ucp_tag_t tag, const ucp_dt_iov_t* iov, const
size_t iov_cnt,
+ void* user_data, ucp_send_nbx_callback_t cb,
+ const ucs_memory_type_t memory_type);
+ arrow::Status SendTagSync(ucp_tag_t tag, const void* buffer, const size_t
count);
+
+ protected:
+ static void err_cb(void* arg, ucp_ep_h ep, ucs_status_t status) {
+ if (!is_ignorable_disconnect_error(status)) {
+ ARROW_LOG(DEBUG) << FromUcsStatus("error handling callback",
status).ToString();
+ }
+ Connection* cnxn = reinterpret_cast<Connection*>(arg);
+ cnxn->closed_ = true;
+ }
+
+ inline arrow::Status CheckClosed() {
+ if (!remote_endpoint_) {
+ return arrow::Status::Invalid("connection is closed");
+ }
+ return arrow::Status::OK();
+ }
+
+ private:
+ std::shared_ptr<utils::UcpWorker> ucp_worker_;
+ ucp_ep_h remote_endpoint_;
+
+ bool closed_{false};
+};
+} // namespace utils
diff --git a/dissociated-ipc/ucx_server.cc b/dissociated-ipc/ucx_server.cc
new file mode 100644
index 0000000..22398a7
--- /dev/null
+++ b/dissociated-ipc/ucx_server.cc
@@ -0,0 +1,280 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+#include <arpa/inet.h>
+#include <netdb.h>
+
+#include "ucx_server.h"
+
+#include "arrow/result.h"
+#include "arrow/status.h"
+#include "arrow/util/io_util.h"
+#include "arrow/util/string.h"
+
+namespace {
+arrow::Result<std::shared_ptr<utils::UcpContext>> init_ucx(
+ struct sockaddr_storage connect_addr) {
+ ucp_config_t* ucp_config;
+ ucp_params_t ucp_params;
+ ucs_status_t status = ucp_config_read(nullptr, nullptr, &ucp_config);
+ RETURN_NOT_OK(utils::FromUcsStatus("ucp_config_read", status));
+
+ // if location is ipv6, adjust config
+ if (connect_addr.ss_family == AF_INET6) {
+ status = ucp_config_modify(ucp_config, "AF_PRIO", "inet6");
+ RETURN_NOT_OK(utils::FromUcsStatus("ucp_config_modify", status));
+ }
+
+ std::memset(&ucp_params, 0, sizeof(ucp_params));
+ ucp_params.field_mask =
+ UCP_PARAM_FIELD_FEATURES | UCP_PARAM_FIELD_NAME |
UCP_PARAM_FIELD_MT_WORKERS_SHARED;
+ ucp_params.features = UCP_FEATURE_AM | UCP_FEATURE_TAG | UCP_FEATURE_RMA |
+ UCP_FEATURE_WAKEUP | UCP_FEATURE_STREAM;
+ ucp_params.mt_workers_shared = UCS_THREAD_MODE_MULTI;
+ ucp_params.name = "cuda-flight-ucx";
+
+ ucp_context_h ucp_context;
+ status = ucp_init(&ucp_params, ucp_config, &ucp_context);
+ ucp_config_release(ucp_config);
+ RETURN_NOT_OK(utils::FromUcsStatus("ucp_init", status));
+ return std::make_shared<utils::UcpContext>(ucp_context);
+}
+
+arrow::Result<std::shared_ptr<utils::UcpWorker>> create_listener_worker(
+ std::shared_ptr<utils::UcpContext> ctx) {
+ ucp_worker_params_t worker_params;
+ ucs_status_t status;
+
+ std::memset(&worker_params, 0, sizeof(worker_params));
+ worker_params.field_mask = UCP_WORKER_PARAM_FIELD_THREAD_MODE;
+ worker_params.thread_mode = UCS_THREAD_MODE_SINGLE;
+
+ ucp_worker_h worker;
+ status = ucp_worker_create(ctx->get(), &worker_params, &worker);
+ RETURN_NOT_OK(utils::FromUcsStatus("ucp_worker_create", status));
+ return std::make_shared<utils::UcpWorker>(std::move(ctx), worker);
+}
+} // namespace
+
+arrow::Status UcxServer::Init(const std::string& host, const int32_t port) {
+ struct sockaddr_storage listen_addr;
+ ARROW_ASSIGN_OR_RAISE(auto addrlen, utils::to_sockaddr(host, port,
&listen_addr));
+
+ ARROW_ASSIGN_OR_RAISE(ucp_context_, init_ucx(listen_addr));
+ ARROW_ASSIGN_OR_RAISE(worker_conn_, create_listener_worker(ucp_context_));
+
+ {
+ ucp_listener_params_t params;
+ ucs_status_t status;
+
+ params.field_mask =
+ UCP_LISTENER_PARAM_FIELD_SOCK_ADDR |
UCP_LISTENER_PARAM_FIELD_CONN_HANDLER;
+ params.sockaddr.addr = reinterpret_cast<const sockaddr*>(&listen_addr);
+ params.sockaddr.addrlen = addrlen;
+ params.conn_handler.cb = HandleIncomingConnection;
+ params.conn_handler.arg = this;
+
+ status = ucp_listener_create(worker_conn_->get(), ¶ms, &listener_);
+ RETURN_NOT_OK(utils::FromUcsStatus("ucp_listener_create", status));
+
+ // get real address/port
+ ucp_listener_attr_t attr;
+ attr.field_mask = UCP_LISTENER_ATTR_FIELD_SOCKADDR;
+ status = ucp_listener_query(listener_, &attr);
+ RETURN_NOT_OK(utils::FromUcsStatus("ucp_listener_query", status));
+
+ std::string raw_uri = "ucx://";
+ if (host.find(":") != std::string::npos) {
+ raw_uri += '[';
+ raw_uri += host;
+ raw_uri += ']';
+ } else {
+ raw_uri += host;
+ }
+
+ using arrow::internal::ToChars;
+
+ raw_uri += ":";
+ raw_uri +=
+ ToChars(ntohs(reinterpret_cast<const
sockaddr_in*>(&attr.sockaddr)->sin_port));
+
+ ARROW_ASSIGN_OR_RAISE(location_, arrow::flight::Location::Parse(raw_uri));
+ }
+
+ {
+ listening_.store(true);
+ std::thread listener_thread(&UcxServer::DriveConnections, this);
+ listener_thread_.swap(listener_thread);
+ }
+
+ return arrow::Status::OK();
+}
+
+arrow::Status UcxServer::Wait() {
+ std::lock_guard<std::mutex> guard(join_mutex_);
+ try {
+ listener_thread_.join();
+ } catch (const std::system_error& e) {
+ if (e.code() != std::errc::invalid_argument) {
+ return arrow::Status::UnknownError("could not Wait(): ", e.what());
+ }
+ // else server wasn't running anyways
+ }
+ return arrow::Status::OK();
+}
+
+arrow::Status UcxServer::Shutdown() {
+ if (!listening_.load()) return arrow::Status::OK();
+
+ arrow::Status status;
+ // wait for current running things to finish
+ listening_.store(false);
+ RETURN_NOT_OK(
+ utils::FromUcsStatus("ucp_worker_signal",
ucp_worker_signal(worker_conn_->get())));
+ status &= Wait();
+
+ {
+ // reject all pending connections
+ std::lock_guard<std::mutex> guard(pending_connections_mutex_);
+ while (!pending_connections_.empty()) {
+ status &= utils::FromUcsStatus(
+ "ucp_listener_reject",
+ ucp_listener_reject(listener_, pending_connections_.front()));
+ pending_connections_.pop();
+ }
+ ucp_listener_destroy(listener_);
+ worker_conn_.reset();
+ }
+
+ ucp_context_.reset();
+ return status;
+}
+
+void UcxServer::DriveConnections() {
+ while (listening_.load()) {
+ // wait for server to recieve connection request from client
+ while (ucp_worker_progress(worker_conn_->get())) {
+ }
+ {
+ // check for requests in queue
+ std::lock_guard<std::mutex> guard(pending_connections_mutex_);
+ while (!pending_connections_.empty()) {
+ ucp_conn_request_h request = pending_connections_.front();
+ pending_connections_.pop();
+
+ std::thread(&UcxServer::HandleConnection, this, request).detach();
+ }
+ }
+
+ // check listening_ in case we're shutting down.
+ // it's possible that shutdown was called while we were in
+ // ucp_worker_progress above, in which case if we don't check
+ // listening_ here, we'll enter ucp_worker_wait and get stuck.
+ if (!listening_.load()) break;
+ auto status = ucp_worker_wait(worker_conn_->get());
+ if (status != UCS_OK) {
+ ARROW_LOG(WARNING) << utils::FromUcsStatus("ucp_worker_wait",
status).ToString();
+ }
+ }
+}
+
+void UcxServer::HandleConnection(ucp_conn_request_h request) {
+ using arrow::internal::ToChars;
+ std::string peer = "unknown:" + ToChars(counter_++);
+ {
+ ucp_conn_request_attr_t request_attr;
+ std::memset(&request_attr, 0, sizeof(request_attr));
+ request_attr.field_mask = UCP_CONN_REQUEST_ATTR_FIELD_CLIENT_ADDR;
+ if (ucp_conn_request_query(request, &request_attr) == UCS_OK) {
+
ARROW_UNUSED(utils::SockaddrToString(request_attr.client_address).Value(&peer));
+ }
+ }
+ ARROW_LOG(DEBUG) << peer << ": Received connection request";
+
+ auto maybe_worker = CreateWorker();
+ if (!maybe_worker.ok()) {
+ ARROW_LOG(ERROR) << peer << ": failed to create worker"
+ << maybe_worker.status().ToString();
+ auto status = ucp_listener_reject(listener_, request);
+ if (status != UCS_OK) {
+ ARROW_LOG(ERROR) << peer << ": "
+ << utils::FromUcsStatus("ucp_listener_reject",
status).ToString();
+ }
+ return;
+ }
+
+ auto worker = maybe_worker.MoveValueUnsafe();
+ worker->conn_ = std::make_unique<utils::Connection>(worker->worker_);
+ auto status = worker->conn_->CreateEndpoint(request);
+ if (!status.ok()) {
+ ARROW_LOG(ERROR) << peer << ": failed to create endpoint and connection: "
+ << status.ToString();
+ return;
+ }
+
+ if (cuda_context_) {
+ auto result =
cuCtxPushCurrent(reinterpret_cast<CUcontext>(cuda_context_->handle()));
+ if (result != CUDA_SUCCESS) {
+ const char* err_name = "\0";
+ const char* err_string = "\0";
+ cuGetErrorName(result, &err_name);
+ cuGetErrorString(result, &err_string);
+ ARROW_LOG(ERROR) << peer << ": failed pushing cuda context on thread: "
<< err_name
+ << " - " << err_string;
+ return;
+ }
+ }
+
+ auto st = do_work(worker.get());
+ if (!st.ok()) {
+ ARROW_LOG(ERROR) << peer << ": error from do_work: " << st.ToString();
+ }
+
+ while (!worker->conn_->is_closed()) {
+ worker->conn_->Progress();
+ }
+
+ // clean up
+ status = worker->conn_->Close();
+ if (!status.ok()) {
+ ARROW_LOG(ERROR) << peer
+ << ": failed to close worker connection: " <<
status.ToString();
+ }
+ worker->worker_.reset();
+ worker->conn_.reset();
+ ARROW_LOG(DEBUG) << peer << ": disconnected";
+}
+
+arrow::Result<std::shared_ptr<UcxServer::ClientWorker>>
UcxServer::CreateWorker() {
+ auto worker = std::make_shared<ClientWorker>();
+
+ ucp_worker_params_t worker_params;
+ std::memset(&worker_params, 0, sizeof(worker_params));
+ worker_params.field_mask =
+ UCP_WORKER_PARAM_FIELD_THREAD_MODE | UCP_WORKER_PARAM_FIELD_FLAGS;
+ worker_params.thread_mode = UCS_THREAD_MODE_MULTI;
+ worker_params.flags = UCP_WORKER_FLAG_IGNORE_REQUEST_LEAK;
+
+ ucp_worker_h ucp_worker;
+ ARROW_RETURN_NOT_OK(utils::FromUcsStatus(
+ "ucp_worker_create",
+ ucp_worker_create(ucp_context_->get(), &worker_params, &ucp_worker)));
+
+ worker->worker_ = std::make_shared<utils::UcpWorker>(ucp_context_,
ucp_worker);
+ ARROW_RETURN_NOT_OK(setup_handlers(worker.get()));
+ return worker;
+}
diff --git a/dissociated-ipc/ucx_server.h b/dissociated-ipc/ucx_server.h
new file mode 100644
index 0000000..94290bf
--- /dev/null
+++ b/dissociated-ipc/ucx_server.h
@@ -0,0 +1,88 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+#pragma once
+
+#include <atomic>
+#include <memory>
+#include <queue>
+#include <string>
+#include <thread>
+#include <utility>
+
+#include "ucx_conn.h"
+#include "ucx_utils.h"
+
+#include "arrow/flight/types.h"
+#include "arrow/gpu/cuda_context.h"
+#include "arrow/status.h"
+
+class UcxServer {
+ public:
+ virtual ~UcxServer() = default;
+ arrow::Status Init(const std::string& host, const int32_t port);
+
+ arrow::Status Wait();
+ virtual arrow::Status Shutdown();
+
+ inline void set_cuda_context(std::shared_ptr<arrow::cuda::CudaContext> ctx) {
+ cuda_context_ = std::move(ctx);
+ }
+
+ protected:
+ inline arrow::flight::Location location() const { return location_; }
+
+ struct ClientWorker {
+ std::shared_ptr<utils::UcpWorker> worker_;
+ std::unique_ptr<utils::Connection> conn_;
+ };
+
+ virtual arrow::Status setup_handlers(ClientWorker* worker) = 0;
+ virtual arrow::Status do_work(ClientWorker* worker) = 0;
+
+ private:
+ static void HandleIncomingConnection(ucp_conn_request_h connection_request,
+ void* data) {
+ UcxServer* server = reinterpret_cast<UcxServer*>(data);
+ server->EnqueueClient(connection_request);
+ }
+
+ void DriveConnections();
+ void EnqueueClient(ucp_conn_request_h connection_request) {
+ std::lock_guard<std::mutex> guard(pending_connections_mutex_);
+ pending_connections_.push(connection_request);
+ }
+
+ void HandleConnection(ucp_conn_request_h request);
+ arrow::Result<std::shared_ptr<ClientWorker>> CreateWorker();
+
+ protected:
+ std::atomic<size_t> counter_{0};
+ arrow::flight::Location location_;
+ std::shared_ptr<utils::UcpContext> ucp_context_;
+ std::shared_ptr<utils::UcpWorker> worker_conn_;
+ ucp_listener_h listener_;
+
+ std::atomic<bool> listening_;
+ std::thread listener_thread_;
+ // std::thread::join cannot be called concurrently
+ std::mutex join_mutex_;
+ std::mutex pending_connections_mutex_;
+ std::queue<ucp_conn_request_h> pending_connections_;
+
+ std::shared_ptr<arrow::cuda::CudaContext> cuda_context_;
+};
diff --git a/dissociated-ipc/ucx_utils.cc b/dissociated-ipc/ucx_utils.cc
new file mode 100644
index 0000000..580ccf9
--- /dev/null
+++ b/dissociated-ipc/ucx_utils.cc
@@ -0,0 +1,287 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+#include <arpa/inet.h>
+#include <netdb.h>
+#include <ucp/api/ucp.h>
+
+#include "arrow/status.h"
+#include "arrow/util/io_util.h"
+#include "ucx_utils.h"
+
+namespace utils {
+constexpr char UcxStatusDetail::kTypeId[];
+
+arrow::Result<size_t> to_sockaddr(const std::string& host, const int32_t port,
+ struct sockaddr_storage* addr) {
+ if (host.empty()) {
+ return arrow::Status::Invalid("Must provide a host");
+ } else if (port < 0) {
+ return arrow::Status::Invalid("Must provide a port");
+ }
+
+ std::memset(addr, 0, sizeof(*addr));
+
+ struct addrinfo* info = nullptr;
+ int err = getaddrinfo(host.c_str(), /*service*/ nullptr, /*hints*/ nullptr,
&info);
+ if (err != 0) {
+ if (err == EAI_SYSTEM) {
+ return arrow::internal::IOErrorFromErrno(errno, "[getaddrinfo] Failure
resolving ",
+ host);
+ } else {
+ return arrow::Status::IOError("[getaddrinfo] Failure resolving ", host,
": ",
+ gai_strerror(err));
+ }
+ }
+
+ struct addrinfo* cur_info = info;
+ while (cur_info) {
+ if (cur_info->ai_family != AF_INET && cur_info->ai_family != AF_INET6) {
+ cur_info = cur_info->ai_next;
+ continue;
+ }
+
+ std::memcpy(addr, cur_info->ai_addr, cur_info->ai_addrlen);
+ if (cur_info->ai_family == AF_INET) {
+ reinterpret_cast<sockaddr_in*>(addr)->sin_port = htons(port);
+ } else if (cur_info->ai_family == AF_INET6) {
+ reinterpret_cast<sockaddr_in6*>(addr)->sin6_port = htons(port);
+ }
+
+ size_t addrlen = cur_info->ai_addrlen;
+ freeaddrinfo(info);
+ return addrlen;
+ }
+
+ if (info) freeaddrinfo(info);
+ return arrow::Status::IOError("[getaddrinfo] Failure resolving ", host,
+ ": no results of a supported family returned");
+}
+
+arrow::Result<std::string> SockaddrToString(const struct sockaddr_storage&
address) {
+ std::string result = "";
+ if (address.ss_family != AF_INET && address.ss_family != AF_INET6) {
+ return arrow::Status::NotImplemented("unknown address family");
+ }
+
+ uint16_t port = 0;
+ if (address.ss_family == AF_INET) {
+ result.resize(INET_ADDRSTRLEN + 1);
+ const auto* in_addr = reinterpret_cast<const struct
sockaddr_in*>(&address);
+ if (!inet_ntop(address.ss_family, &in_addr->sin_addr, &result[0],
INET_ADDRSTRLEN)) {
+ return arrow::internal::IOErrorFromErrno(errno,
+ "could not convert address to a
string");
+ }
+ port = ntohs(in_addr->sin_port);
+ } else {
+ result.resize(INET6_ADDRSTRLEN + 1);
+ const auto* in6_addr = reinterpret_cast<const struct
sockaddr_in6*>(&address);
+ if (!inet_ntop(address.ss_family, &in6_addr->sin6_addr, &result[0],
+ INET6_ADDRSTRLEN)) {
+ return arrow::internal::IOErrorFromErrno(errno,
+ "could not convert address to
string");
+ }
+ port = ntohs(in6_addr->sin6_port);
+ }
+
+ const size_t pos = result.find('\0');
+ DCHECK_NE(pos, std::string::npos);
+ result[pos] = ':';
+ result.resize(pos + 1);
+ result += std::to_string(port);
+
+ return result;
+}
+
+std::string UcxStatusDetail::ToString() const { return
ucs_status_string(status_); }
+ucs_status_t UcxStatusDetail::Unwrap(const arrow::Status& status) {
+ if (!status.detail() || status.detail()->type_id() != kTypeId) return UCS_OK;
+ return dynamic_cast<const UcxStatusDetail*>(status.detail().get())->status_;
+}
+
+arrow::Status FromUcsStatus(const std::string& context, ucs_status_t
ucs_status) {
+ switch (ucs_status) {
+ case UCS_OK:
+ return arrow::Status::OK();
+ case UCS_INPROGRESS:
+ return arrow::Status::IOError(context, ": UCX error ",
+ static_cast<int32_t>(ucs_status), ": ",
+ "UCS_INPROGRESS ",
ucs_status_string(ucs_status))
+ .WithDetail(std::make_shared<UcxStatusDetail>(ucs_status));
+ case UCS_ERR_NO_MESSAGE:
+ return arrow::Status::IOError(context, ": UCX error ",
+ static_cast<int32_t>(ucs_status), ": ",
+ "UCS_ERR_NO_MESSAGE ",
ucs_status_string(ucs_status))
+ .WithDetail(std::make_shared<UcxStatusDetail>(ucs_status));
+ case UCS_ERR_NO_RESOURCE:
+ return arrow::Status::IOError(context, ": UCX error ",
+ static_cast<int32_t>(ucs_status), ": ",
+ "UCS_ERR_NO_RESOURCE ",
ucs_status_string(ucs_status))
+ .WithDetail(std::make_shared<UcxStatusDetail>(ucs_status));
+ case UCS_ERR_IO_ERROR:
+ return arrow::Status::IOError(context, ": UCX error ",
+ static_cast<int32_t>(ucs_status), ": ",
+ "UCS_ERR_IO_ERROR ",
ucs_status_string(ucs_status))
+ .WithDetail(std::make_shared<UcxStatusDetail>(ucs_status));
+ case UCS_ERR_NO_MEMORY:
+ return arrow::Status::OutOfMemory(
+ context, ": UCX error ", static_cast<int32_t>(ucs_status), ":
",
+ "UCS_ERR_NO_MEMORY ", ucs_status_string(ucs_status))
+ .WithDetail(std::make_shared<UcxStatusDetail>(ucs_status));
+ case UCS_ERR_INVALID_PARAM:
+ return arrow::Status::Invalid(
+ context, ": UCX error ", static_cast<int32_t>(ucs_status), ":
",
+ "UCS_ERR_INVALID_PARAM ", ucs_status_string(ucs_status))
+ .WithDetail(std::make_shared<UcxStatusDetail>(ucs_status));
+ case UCS_ERR_UNREACHABLE:
+ return arrow::Status::IOError(context, ": UCX error ",
+ static_cast<int32_t>(ucs_status), ": ",
+ "UCS_ERR_UNREACHABLE ",
ucs_status_string(ucs_status))
+ .WithDetail(std::make_shared<UcxStatusDetail>(ucs_status));
+ case UCS_ERR_INVALID_ADDR:
+ return arrow::Status::Invalid(
+ context, ": UCX error ", static_cast<int32_t>(ucs_status), ":
",
+ "UCS_ERR_INVALID_ADDR ", ucs_status_string(ucs_status))
+ .WithDetail(std::make_shared<UcxStatusDetail>(ucs_status));
+ case UCS_ERR_NOT_IMPLEMENTED:
+ return arrow::Status::NotImplemented(
+ context, ": UCX error ", static_cast<int32_t>(ucs_status), ":
",
+ "UCS_ERR_NOT_IMPLEMENTED ", ucs_status_string(ucs_status))
+ .WithDetail(std::make_shared<UcxStatusDetail>(ucs_status));
+ case UCS_ERR_MESSAGE_TRUNCATED:
+ return arrow::Status::IOError(
+ context, ": UCX error ", static_cast<int32_t>(ucs_status), ":
",
+ "UCS_ERR_MESSAGE_TRUNCATED ", ucs_status_string(ucs_status))
+ .WithDetail(std::make_shared<UcxStatusDetail>(ucs_status));
+ case UCS_ERR_NO_PROGRESS:
+ return arrow::Status::IOError(context, ": UCX error ",
+ static_cast<int32_t>(ucs_status), ": ",
+ "UCS_ERR_NO_PROGRESS ",
ucs_status_string(ucs_status))
+ .WithDetail(std::make_shared<UcxStatusDetail>(ucs_status));
+ case UCS_ERR_BUFFER_TOO_SMALL:
+ return arrow::Status::Invalid(
+ context, ": UCX error ", static_cast<int32_t>(ucs_status), ":
",
+ "UCS_ERR_BUFFER_TOO_SMALL ", ucs_status_string(ucs_status))
+ .WithDetail(std::make_shared<UcxStatusDetail>(ucs_status));
+ case UCS_ERR_NO_ELEM:
+ return arrow::Status::IOError(context, ": UCX error ",
+ static_cast<int32_t>(ucs_status), ": ",
+ "UCS_ERR_NO_ELEM ",
ucs_status_string(ucs_status))
+ .WithDetail(std::make_shared<UcxStatusDetail>(ucs_status));
+ case UCS_ERR_SOME_CONNECTS_FAILED:
+ return arrow::Status::IOError(
+ context, ": UCX error ", static_cast<int32_t>(ucs_status), ":
",
+ "UCS_ERR_SOME_CONNECTS_FAILED ",
ucs_status_string(ucs_status))
+ .WithDetail(std::make_shared<UcxStatusDetail>(ucs_status));
+ case UCS_ERR_NO_DEVICE:
+ return arrow::Status::IOError(context, ": UCX error ",
+ static_cast<int32_t>(ucs_status), ": ",
+ "UCS_ERR_NO_DEVICE ",
ucs_status_string(ucs_status))
+ .WithDetail(std::make_shared<UcxStatusDetail>(ucs_status));
+ case UCS_ERR_BUSY:
+ return arrow::Status::IOError(context, ": UCX error ",
+ static_cast<int32_t>(ucs_status), ": ",
+ "UCS_ERR_BUSY ",
ucs_status_string(ucs_status))
+ .WithDetail(std::make_shared<UcxStatusDetail>(ucs_status));
+ case UCS_ERR_CANCELED:
+ return arrow::Status::Cancelled(context, ": UCX error ",
+ static_cast<int32_t>(ucs_status), ": ",
+ "UCS_ERR_CANCELED ",
ucs_status_string(ucs_status))
+ .WithDetail(std::make_shared<UcxStatusDetail>(ucs_status));
+ case UCS_ERR_SHMEM_SEGMENT:
+ return arrow::Status::IOError(
+ context, ": UCX error ", static_cast<int32_t>(ucs_status), ":
",
+ "UCS_ERR_SHMEM_SEGMENT ", ucs_status_string(ucs_status))
+ .WithDetail(std::make_shared<UcxStatusDetail>(ucs_status));
+ case UCS_ERR_ALREADY_EXISTS:
+ return arrow::Status::AlreadyExists(
+ context, ": UCX error ", static_cast<int32_t>(ucs_status), ":
",
+ "UCS_ERR_ALREADY_EXISTS ", ucs_status_string(ucs_status))
+ .WithDetail(std::make_shared<UcxStatusDetail>(ucs_status));
+ case UCS_ERR_OUT_OF_RANGE:
+ return arrow::Status::IOError(
+ context, ": UCX error ", static_cast<int32_t>(ucs_status), ":
",
+ "UCS_ERR_OUT_OF_RANGE ", ucs_status_string(ucs_status))
+ .WithDetail(std::make_shared<UcxStatusDetail>(ucs_status));
+ case UCS_ERR_TIMED_OUT:
+ return arrow::Status::Cancelled(context, ": UCX error ",
+ static_cast<int32_t>(ucs_status), ": ",
+ "UCS_ERR_TIMED_OUT ",
ucs_status_string(ucs_status))
+ .WithDetail(std::make_shared<UcxStatusDetail>(ucs_status));
+ case UCS_ERR_EXCEEDS_LIMIT:
+ return arrow::Status::IOError(
+ context, ": UCX error ", static_cast<int32_t>(ucs_status), ":
",
+ "UCS_ERR_EXCEEDS_LIMIT ", ucs_status_string(ucs_status))
+ .WithDetail(std::make_shared<UcxStatusDetail>(ucs_status));
+ case UCS_ERR_UNSUPPORTED:
+ return arrow::Status::NotImplemented(
+ context, ": UCX error ", static_cast<int32_t>(ucs_status), ":
",
+ "UCS_ERR_UNSUPPORTED ", ucs_status_string(ucs_status))
+ .WithDetail(std::make_shared<UcxStatusDetail>(ucs_status));
+ case UCS_ERR_REJECTED:
+ return arrow::Status::IOError(context, ": UCX error ",
+ static_cast<int32_t>(ucs_status), ": ",
+ "UCS_ERR_REJECTED ",
ucs_status_string(ucs_status))
+ .WithDetail(std::make_shared<UcxStatusDetail>(ucs_status));
+ case UCS_ERR_NOT_CONNECTED:
+ return arrow::Status::IOError(
+ context, ": UCX error ", static_cast<int32_t>(ucs_status), ":
",
+ "UCS_ERR_NOT_CONNECTED ", ucs_status_string(ucs_status))
+ .WithDetail(std::make_shared<UcxStatusDetail>(ucs_status));
+ case UCS_ERR_CONNECTION_RESET:
+ return arrow::Status::IOError(
+ context, ": UCX error ", static_cast<int32_t>(ucs_status), ":
",
+ "UCS_ERR_CONNECTION_RESET ", ucs_status_string(ucs_status))
+ .WithDetail(std::make_shared<UcxStatusDetail>(ucs_status));
+ case UCS_ERR_FIRST_LINK_FAILURE:
+ return arrow::Status::IOError(
+ context, ": UCX error ", static_cast<int32_t>(ucs_status), ":
",
+ "UCS_ERR_FIRST_LINK_FAILURE ", ucs_status_string(ucs_status))
+ .WithDetail(std::make_shared<UcxStatusDetail>(ucs_status));
+ case UCS_ERR_LAST_LINK_FAILURE:
+ return arrow::Status::IOError(
+ context, ": UCX error ", static_cast<int32_t>(ucs_status), ":
",
+ "UCS_ERR_LAST_LINK_FAILURE ", ucs_status_string(ucs_status))
+ .WithDetail(std::make_shared<UcxStatusDetail>(ucs_status));
+ case UCS_ERR_FIRST_ENDPOINT_FAILURE:
+ return arrow::Status::IOError(
+ context, ": UCX error ", static_cast<int32_t>(ucs_status), ":
",
+ "UCS_ERR_FIRST_ENDPOINT_FAILURE ",
ucs_status_string(ucs_status))
+ .WithDetail(std::make_shared<UcxStatusDetail>(ucs_status));
+ case UCS_ERR_LAST_ENDPOINT_FAILURE:
+ return arrow::Status::IOError(
+ context, ": UCX error ", static_cast<int32_t>(ucs_status), ":
",
+ "UCS_ERR_LAST_ENDPOINT_FAILURE ",
ucs_status_string(ucs_status))
+ .WithDetail(std::make_shared<UcxStatusDetail>(ucs_status));
+ case UCS_ERR_ENDPOINT_TIMEOUT:
+ return arrow::Status::IOError(
+ context, ": UCX error ", static_cast<int32_t>(ucs_status), ":
",
+ "UCS_ERR_ENDPOINT_TIMEOUT ", ucs_status_string(ucs_status))
+ .WithDetail(std::make_shared<UcxStatusDetail>(ucs_status));
+ case UCS_ERR_LAST:
+ return arrow::Status::IOError(context, ": UCX error ",
+ static_cast<int32_t>(ucs_status), ": ",
+ "UCS_ERR_LAST ",
ucs_status_string(ucs_status))
+ .WithDetail(std::make_shared<UcxStatusDetail>(ucs_status));
+ default:
+ return arrow::Status::UnknownError(
+ context, ": Unknown UCX error: ",
static_cast<int32_t>(ucs_status), " ",
+ ucs_status_string(ucs_status))
+ .WithDetail(std::make_shared<UcxStatusDetail>(ucs_status));
+ }
+}
+} // namespace utils
diff --git a/dissociated-ipc/ucx_utils.h b/dissociated-ipc/ucx_utils.h
new file mode 100644
index 0000000..7bc9228
--- /dev/null
+++ b/dissociated-ipc/ucx_utils.h
@@ -0,0 +1,122 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+#pragma once
+
+#include <memory>
+#include <string>
+#include <utility>
+
+#include <ucp/api/ucp.h>
+
+#include <arrow/buffer.h>
+#include <arrow/status.h>
+#include <arrow/util/endian.h>
+#include <arrow/util/logging.h>
+#include <arrow/util/ubsan.h>
+
+namespace utils {
// Stores `in` into out[0..3] in little-endian order (least significant byte
// first). The result is the same on any host endianness, and `out` needs no
// particular alignment.
static inline void Uint32ToBytesLE(const uint32_t in, uint8_t* out) {
  for (int byte = 0; byte < 4; ++byte) {
    out[byte] = static_cast<uint8_t>(in >> (8 * byte));
  }
}
+
// Reads a little-endian uint32 from in[0..3] (least significant byte first).
// The result is the same on any host endianness, and `in` needs no
// particular alignment.
static inline uint32_t BytesToUint32LE(const uint8_t* in) {
  uint32_t value = 0;
  for (int byte = 3; byte >= 0; --byte) {
    value = (value << 8) | static_cast<uint32_t>(in[byte]);
  }
  return value;
}
+
+class UcpContext final {
+ public:
+ UcpContext() = default;
+ explicit UcpContext(ucp_context_h context) : ucp_context_(context) {}
+ ~UcpContext() {
+ if (ucp_context_) ucp_cleanup(ucp_context_);
+ ucp_context_ = nullptr;
+ }
+
+ ucp_context_h get() const {
+ DCHECK(ucp_context_);
+ return ucp_context_;
+ }
+
+ private:
+ ucp_context_h ucp_context_{nullptr};
+};
+
+class UcpWorker final {
+ public:
+ UcpWorker() = default;
+ UcpWorker(std::shared_ptr<UcpContext> context, ucp_worker_h worker)
+ : ucp_context_(std::move(context)), ucp_worker_(worker) {}
+ ~UcpWorker() {
+ if (ucp_worker_) ucp_worker_destroy(ucp_worker_);
+ ucp_worker_ = nullptr;
+ }
+
+ ucp_worker_h get() const { return ucp_worker_; }
+ const UcpContext& context() const { return *ucp_context_; }
+
+ private:
+ ucp_worker_h ucp_worker_{nullptr};
+ std::shared_ptr<UcpContext> ucp_context_;
+};
+
+class UcxStatusDetail : public arrow::StatusDetail {
+ public:
+ explicit UcxStatusDetail(ucs_status_t status) : status_(status) {}
+ static constexpr char const kTypeId[] = "ucx::UcxStatusDetail";
+
+ const char* type_id() const override { return kTypeId; }
+ std::string ToString() const override;
+ static ucs_status_t Unwrap(const arrow::Status& status);
+
+ private:
+ ucs_status_t status_;
+};
+
+arrow::Status FromUcsStatus(const std::string& context, ucs_status_t
ucs_status);
+
+class UcxDataBuffer : public arrow::Buffer {
+ public:
+ UcxDataBuffer(std::shared_ptr<UcpWorker> worker, void* data, const size_t
size)
+ : arrow::Buffer(reinterpret_cast<uint8_t*>(data),
static_cast<int64_t>(size)),
+ worker_(std::move(worker)) {}
+ ~UcxDataBuffer() override {
+ ucp_am_data_release(worker_->get(),
+ const_cast<void*>(reinterpret_cast<const
void*>(data())));
+ }
+
+ private:
+ std::shared_ptr<UcpWorker> worker_;
+};
+
+arrow::Result<size_t> to_sockaddr(const std::string& host, const int32_t port,
+ struct sockaddr_storage* addr);
+arrow::Result<std::string> SockaddrToString(const struct sockaddr_storage&
address);
+
+static inline bool is_ignorable_disconnect_error(ucs_status_t ucs_status) {
+ // not connected, connection reset: we're already disconnected
+ // timeout: most likely disconnected, but we can't tell from our end
+ switch (ucs_status) {
+ case UCS_OK:
+ case UCS_ERR_ENDPOINT_TIMEOUT:
+ case UCS_ERR_NOT_CONNECTED:
+ case UCS_ERR_CONNECTION_RESET:
+ return true;
+ }
+ return false;
+}
+} // namespace utils