Author: Valentin Clement (バレンタイン クレメン) Date: 2024-08-14T13:44:34-07:00 New Revision: ec137c84529a30b16c051f6b633a7a7538bdc46c
URL: https://github.com/llvm/llvm-project/commit/ec137c84529a30b16c051f6b633a7a7538bdc46c DIFF: https://github.com/llvm/llvm-project/commit/ec137c84529a30b16c051f6b633a7a7538bdc46c.diff LOG: Revert "[flang][cuda] Use cuda runtime API (#103488)" This reverts commit 00ab8a6a4c3811c50a9dc9626e6fa067fdfcd474. Added: Modified: flang/include/flang/Runtime/CUDA/allocator.h flang/runtime/CUDA/CMakeLists.txt flang/runtime/CUDA/allocator.cpp flang/unittests/Runtime/CUDA/AllocatorCUF.cpp Removed: ################################################################################ diff --git a/flang/include/flang/Runtime/CUDA/allocator.h b/flang/include/flang/Runtime/CUDA/allocator.h index 4527c9f18fa054..f0bfc1548e6458 100644 --- a/flang/include/flang/Runtime/CUDA/allocator.h +++ b/flang/include/flang/Runtime/CUDA/allocator.h @@ -13,10 +13,11 @@ #include "flang/Runtime/entry-names.h" #define CUDA_REPORT_IF_ERROR(expr) \ - [](cudaError_t err) { \ - if (err == cudaSuccess) \ + [](CUresult result) { \ + if (!result) \ return; \ - const char *name = cudaGetErrorName(err); \ + const char *name = nullptr; \ + cuGetErrorName(result, &name); \ if (!name) \ name = "<unknown>"; \ Terminator terminator{__FILE__, __LINE__}; \ diff --git a/flang/runtime/CUDA/CMakeLists.txt b/flang/runtime/CUDA/CMakeLists.txt index 53c5b8823c56b0..88243536139e46 100644 --- a/flang/runtime/CUDA/CMakeLists.txt +++ b/flang/runtime/CUDA/CMakeLists.txt @@ -7,20 +7,14 @@ #===------------------------------------------------------------------------===# include_directories(${CUDAToolkit_INCLUDE_DIRS}) +find_library(CUDA_RUNTIME_LIBRARY cuda HINTS ${CMAKE_CUDA_IMPLICIT_LINK_DIRECTORIES} REQUIRED) add_flang_library(CufRuntime allocator.cpp descriptor.cpp ) - -if (BUILD_SHARED_LIBS) - set(CUF_LIBRARY ${CUDA_LIBRARIES}) -else() - set(CUF_LIBRARY ${CUDA_cudart_static_LIBRARY}) -endif() - target_link_libraries(CufRuntime PRIVATE FortranRuntime - ${CUF_LIBRARY} + ${CUDA_RUNTIME_LIBRARY} ) diff --git a/flang/runtime/CUDA/allocator.cpp b/flang/runtime/CUDA/allocator.cpp index d4a473d58e86cd..bd657b800c61e8 100644 --- a/flang/runtime/CUDA/allocator.cpp +++ b/flang/runtime/CUDA/allocator.cpp @@ -15,7 +15,7 @@ #include "flang/ISO_Fortran_binding_wrapper.h" #include "flang/Runtime/allocator-registry.h" -#include "cuda_runtime.h" +#include "cuda.h" namespace Fortran::runtime::cuda { extern "C" { @@ -34,28 +34,32 @@ void RTDEF(CUFRegisterAllocator)() { void *CUFAllocPinned(std::size_t sizeInBytes) { void *p; - CUDA_REPORT_IF_ERROR(cudaMallocHost((void **)&p, sizeInBytes)); + CUDA_REPORT_IF_ERROR(cuMemAllocHost(&p, sizeInBytes)); return p; } -void CUFFreePinned(void *p) { CUDA_REPORT_IF_ERROR(cudaFreeHost(p)); } +void CUFFreePinned(void *p) { CUDA_REPORT_IF_ERROR(cuMemFreeHost(p)); } void *CUFAllocDevice(std::size_t sizeInBytes) { - void *p; - CUDA_REPORT_IF_ERROR(cudaMalloc(&p, sizeInBytes)); - return p; + CUdeviceptr p = 0; + CUDA_REPORT_IF_ERROR(cuMemAlloc(&p, sizeInBytes)); + return reinterpret_cast<void *>(p); } -void CUFFreeDevice(void *p) { CUDA_REPORT_IF_ERROR(cudaFree(p)); } +void CUFFreeDevice(void *p) { + CUDA_REPORT_IF_ERROR(cuMemFree(reinterpret_cast<CUdeviceptr>(p))); +} void *CUFAllocManaged(std::size_t sizeInBytes) { - void *p; + CUdeviceptr p = 0; CUDA_REPORT_IF_ERROR( - cudaMallocManaged((void **)&p, sizeInBytes, cudaMemAttachGlobal)); + cuMemAllocManaged(&p, sizeInBytes, CU_MEM_ATTACH_GLOBAL)); return reinterpret_cast<void *>(p); } -void CUFFreeManaged(void *p) { CUDA_REPORT_IF_ERROR(cudaFree(p)); } +void CUFFreeManaged(void *p) { + CUDA_REPORT_IF_ERROR(cuMemFree(reinterpret_cast<CUdeviceptr>(p))); +} void *CUFAllocUnified(std::size_t sizeInBytes) { // Call alloc managed for the time being. diff --git a/flang/unittests/Runtime/CUDA/AllocatorCUF.cpp b/flang/unittests/Runtime/CUDA/AllocatorCUF.cpp index b51ff0ac006cc6..9f5ec289ee8f74 100644 --- a/flang/unittests/Runtime/CUDA/AllocatorCUF.cpp +++ b/flang/unittests/Runtime/CUDA/AllocatorCUF.cpp @@ -14,7 +14,7 @@ #include "flang/Runtime/allocatable.h" #include "flang/Runtime/allocator-registry.h" -#include "cuda_runtime.h" +#include "cuda.h" using namespace Fortran::runtime; using namespace Fortran::runtime::cuda; @@ -25,9 +25,38 @@ static OwningPtr<Descriptor> createAllocatable( CFI_attribute_allocatable); } +thread_local static int32_t defaultDevice = 0; + +CUdevice getDefaultCuDevice() { + CUdevice device; + CUDA_REPORT_IF_ERROR(cuDeviceGet(&device, /*ordinal=*/defaultDevice)); + return device; +} + +class ScopedContext { +public: + ScopedContext() { + // Static reference to CUDA primary context for device ordinal + // defaultDevice. + static CUcontext context = [] { + CUDA_REPORT_IF_ERROR(cuInit(/*flags=*/0)); + CUcontext ctx; + // Note: this does not affect the current context. + CUDA_REPORT_IF_ERROR( + cuDevicePrimaryCtxRetain(&ctx, getDefaultCuDevice())); + return ctx; + }(); + + CUDA_REPORT_IF_ERROR(cuCtxPushCurrent(context)); + } + + ~ScopedContext() { CUDA_REPORT_IF_ERROR(cuCtxPopCurrent(nullptr)); } +}; + TEST(AllocatableCUFTest, SimpleDeviceAllocate) { using Fortran::common::TypeCategory; RTNAME(CUFRegisterAllocator)(); + ScopedContext ctx; // REAL(4), DEVICE, ALLOCATABLE :: a(:) auto a{createAllocatable(TypeCategory::Real, 4)}; a->SetAllocIdx(kDeviceAllocatorPos); @@ -45,6 +74,7 @@ TEST(AllocatableCUFTest, SimpleDeviceAllocate) { TEST(AllocatableCUFTest, SimplePinnedAllocate) { using Fortran::common::TypeCategory; RTNAME(CUFRegisterAllocator)(); + ScopedContext ctx; // INTEGER(4), PINNED, ALLOCATABLE :: a(:) auto a{createAllocatable(TypeCategory::Integer, 4)}; EXPECT_FALSE(a->HasAddendum()); @@ -63,6 +93,7 @@ TEST(AllocatableCUFTest, SimplePinnedAllocate) { TEST(AllocatableCUFTest, DescriptorAllocationTest) { using Fortran::common::TypeCategory; RTNAME(CUFRegisterAllocator)(); + ScopedContext ctx; // REAL(4), DEVICE, ALLOCATABLE :: a(:) auto a{createAllocatable(TypeCategory::Real, 4)}; Descriptor *desc = nullptr; _______________________________________________ llvm-branch-commits mailing list llvm-branch-commits@lists.llvm.org https://lists.llvm.org/cgi-bin/mailman/listinfo/llvm-branch-commits