https://github.com/makslevental updated https://github.com/llvm/llvm-project/pull/121067
>From b8afb0446b85ed7b1e1646a64e0c1b79ef79b04a Mon Sep 17 00:00:00 2001
From: Maksim Levental <maksim.leven...@gmail.com>
Date: Tue, 24 Dec 2024 15:27:31 -0500
Subject: [PATCH] [mlir][rocdl] Add AMDGPU-specific `cf.assert` lowering

---
 .../GPUToROCDL/LowerGpuOpsToROCDLOps.cpp | 72 ++++++++++++++++++-
 mlir/test/Integration/GPU/ROCM/assert.mlir | 39 ++++++++++
 2 files changed, 110 insertions(+), 1 deletion(-)
 create mode 100644 mlir/test/Integration/GPU/ROCM/assert.mlir

diff --git a/mlir/lib/Conversion/GPUToROCDL/LowerGpuOpsToROCDLOps.cpp b/mlir/lib/Conversion/GPUToROCDL/LowerGpuOpsToROCDLOps.cpp
index aaf00e51f49416..2b0d16a5defb1f 100644
--- a/mlir/lib/Conversion/GPUToROCDL/LowerGpuOpsToROCDLOps.cpp
+++ b/mlir/lib/Conversion/GPUToROCDL/LowerGpuOpsToROCDLOps.cpp
@@ -31,6 +31,7 @@
 #include "mlir/Conversion/MemRefToLLVM/MemRefToLLVM.h"
 #include "mlir/Conversion/VectorToLLVM/ConvertVectorToLLVM.h"
 #include "mlir/Dialect/ControlFlow/IR/ControlFlow.h"
+#include "mlir/Dialect/ControlFlow/IR/ControlFlowOps.h"
 #include "mlir/Dialect/Func/IR/FuncOps.h"
 #include "mlir/Dialect/GPU/IR/GPUDialect.h"
 #include "mlir/Dialect/GPU/Transforms/Passes.h"
@@ -195,6 +196,75 @@ struct GPUShuffleOpLowering : public ConvertOpToLLVMPattern<gpu::ShuffleOp> {
   }
 };
 
+/// Based on
+/// mlir/lib/Conversion/GPUToNVVM/LowerGpuOpsToNVVMOps.cpp#AssertOpToAssertfailLowering
+/// Lowering of cf.assert into a conditional llvm.intr.trap plus gpu.printf with
+/// the metadata (filename, fileline, assert msg).
+struct AssertOpToBuiltinTrapLowering
+    : public ConvertOpToLLVMPattern<cf::AssertOp> {
+  using ConvertOpToLLVMPattern<cf::AssertOp>::ConvertOpToLLVMPattern;
+
+  LogicalResult
+  matchAndRewrite(cf::AssertOp assertOp, cf::AssertOpAdaptor adaptor,
+                  ConversionPatternRewriter &rewriter) const override {
+    Location loc = assertOp.getLoc();
+
+    // Split blocks and insert conditional branch.
+    // ^before:
+    //   ...
+    //   cf.cond_br %condition, ^after, ^assert
+    // ^assert:
+    //   cf.assert
+    //   cf.br ^after
+    // ^after:
+    //   ...
+    Block *beforeBlock = assertOp->getBlock();
+    Block *assertBlock =
+        rewriter.splitBlock(beforeBlock, assertOp->getIterator());
+    Block *afterBlock =
+        rewriter.splitBlock(assertBlock, ++assertOp->getIterator());
+    rewriter.setInsertionPointToEnd(beforeBlock);
+    rewriter.create<cf::CondBranchOp>(loc, adaptor.getArg(), afterBlock,
+                                      assertBlock);
+    rewriter.setInsertionPointToEnd(assertBlock);
+    rewriter.create<cf::BranchOp>(loc, afterBlock);
+
+    // Continue cf.assert lowering.
+    rewriter.setInsertionPoint(assertOp);
+
+    // Populate file name, line number and function name from the location of
+    // the AssertOp.
+    StringRef fileName = "(unknown)";
+    StringRef funcName = "(unknown)";
+    int32_t fileLine = 0;
+    if (auto fileLineColLoc = dyn_cast<FileLineColRange>(loc)) {
+      fileName = fileLineColLoc.getFilename().strref();
+      fileLine = fileLineColLoc.getStartLine();
+    } else if (auto nameLoc = dyn_cast<NameLoc>(loc)) {
+      funcName = nameLoc.getName().strref();
+      if (auto fileLineColLoc =
+              dyn_cast<FileLineColRange>(nameLoc.getChildLoc())) {
+        fileName = fileLineColLoc.getFilename().strref();
+        fileLine = fileLineColLoc.getStartLine();
+      }
+    }
+
+    Value assertLine =
+        rewriter.create<LLVM::ConstantOp>(loc, rewriter.getI32Type(), fileLine);
+    // Interpolate the format string ahead of time (gpu.printf lowering doesn't
+    // handle %s); build the Twine inline so no temporary outlives its use.
+    rewriter.create<gpu::PrintfOp>(
+        loc,
+        rewriter.getStringAttr(fileName + ":%u: " + funcName +
+                               " Device-side assertion `" + assertOp.getMsg() +
+                               "' failed.\n"),
+        ValueRange{assertLine});
+    rewriter.replaceOpWithNewOp<LLVM::Trap>(assertOp);
+
+    return success();
+  }
+};
+
 /// Import the GPU Ops to ROCDL Patterns.
#include "GPUToROCDL.cpp.inc" @@ -297,7 +367,7 @@ struct LowerGpuOpsToROCDLOpsPass populateVectorToLLVMConversionPatterns(converter, llvmPatterns); populateMathToLLVMConversionPatterns(converter, llvmPatterns); cf::populateControlFlowToLLVMConversionPatterns(converter, llvmPatterns); - cf::populateAssertToLLVMConversionPattern(converter, llvmPatterns); + llvmPatterns.add<AssertOpToBuiltinTrapLowering>(converter); populateFuncToLLVMConversionPatterns(converter, llvmPatterns); populateFinalizeMemRefToLLVMConversionPatterns(converter, llvmPatterns); populateGpuToROCDLConversionPatterns(converter, llvmPatterns, runtime); diff --git a/mlir/test/Integration/GPU/ROCM/assert.mlir b/mlir/test/Integration/GPU/ROCM/assert.mlir new file mode 100644 index 00000000000000..e1b07d454e61ac --- /dev/null +++ b/mlir/test/Integration/GPU/ROCM/assert.mlir @@ -0,0 +1,39 @@ +// RUN: mlir-opt %s \ +// RUN: | mlir-opt -pass-pipeline='builtin.module(gpu.module(strip-debuginfo,convert-gpu-to-rocdl{index-bitwidth=32 runtime=HIP}),rocdl-attach-target{chip=%chip})' \ +// RUN: | mlir-opt -gpu-to-llvm -reconcile-unrealized-casts -gpu-module-to-binary \ +// RUN: | mlir-cpu-runner \ +// RUN: --shared-libs=%mlir_rocm_runtime \ +// RUN: --shared-libs=%mlir_runner_utils \ +// RUN: --entry-point-result=void 2>&1 \ +// RUN: | FileCheck %s + +// CHECK-DAG: thread 0: print after passing assertion +// CHECK-DAG: thread 1: print after passing assertion +// CHECK-DAG: mlir/test/Integration/GPU/ROCM/assert.mlir:{{.*}}: (unknown) Device-side assertion `failing assertion' failed. +// CHECK-DAG: mlir/test/Integration/GPU/ROCM/assert.mlir:{{.*}}: (unknown) Device-side assertion `failing assertion' failed. 
+// CHECK-NOT: print after failing assertion + +module attributes {gpu.container_module} { +gpu.module @kernels { +gpu.func @test_assert(%c0: i1, %c1: i1) kernel { + %0 = gpu.thread_id x + cf.assert %c1, "passing assertion" + gpu.printf "thread %lld: print after passing assertion\n" %0 : index + cf.assert %c0, "failing assertion" + gpu.printf "thread %lld: print after failing assertion\n" %0 : index + gpu.return +} +} + +func.func @main() { + %c2 = arith.constant 2 : index + %c1 = arith.constant 1 : index + %c0_i1 = arith.constant 0 : i1 + %c1_i1 = arith.constant 1 : i1 + gpu.launch_func @kernels::@test_assert + blocks in (%c1, %c1, %c1) + threads in (%c2, %c1, %c1) + args(%c0_i1 : i1, %c1_i1 : i1) + return +} +} _______________________________________________ llvm-branch-commits mailing list llvm-branch-commits@lists.llvm.org https://lists.llvm.org/cgi-bin/mailman/listinfo/llvm-branch-commits