llvmbot wrote:
<!--LLVM PR SUMMARY COMMENT--> @llvm/pr-subscribers-mlir-core Author: Matthias Springer (matthias-springer) <details> <summary>Changes</summary> Replace a workaround in the implementation of `replaceAllUsesWith` in the no-rollback dialect conversion. This workaround was necessary for `restoreByValRefArgumentType` in the `func-to-llvm` lowering because there was no support for `replaceAllUsesExcept`. Support for this API has been added to the no-rollback driver, so the workaround can be dropped from that driver. The workaround is still in place for the rollback driver. Depends on #<!-- -->169606. --- Full diff: https://github.com/llvm/llvm-project/pull/169609.diff 4 Files Affected: - (modified) mlir/lib/Conversion/FuncToLLVM/FuncToLLVM.cpp (+10-2) - (modified) mlir/lib/Transforms/Utils/DialectConversion.cpp (+12-21) - (modified) mlir/test/Transforms/test-convert-func-op.mlir (+2-1) - (modified) mlir/test/lib/Conversion/FuncToLLVM/TestConvertFuncOp.cpp (+10-1) ``````````diff diff --git a/mlir/lib/Conversion/FuncToLLVM/FuncToLLVM.cpp b/mlir/lib/Conversion/FuncToLLVM/FuncToLLVM.cpp index 2220f61ed8a07..ddd94f5d03042 100644 --- a/mlir/lib/Conversion/FuncToLLVM/FuncToLLVM.cpp +++ b/mlir/lib/Conversion/FuncToLLVM/FuncToLLVM.cpp @@ -283,8 +283,16 @@ static void restoreByValRefArgumentType( Type resTy = typeConverter.convertType( cast<TypeAttr>(byValRefAttr->getValue()).getValue()); - Value valueArg = LLVM::LoadOp::create(rewriter, arg.getLoc(), resTy, arg); - rewriter.replaceAllUsesWith(arg, valueArg); + auto loadOp = LLVM::LoadOp::create(rewriter, arg.getLoc(), resTy, arg); + if (!rewriter.getConfig().allowPatternRollback) { + rewriter.replaceAllUsesExcept(arg, loadOp, loadOp); + } else { + // replaceAllUsesExcept is not supported in rollback mode. The rollback + // mode implementation has a workaround: certain replacements that would + // cause a dominance violation are skipped. + // TODO: Remove workaround. 
+ rewriter.replaceAllUsesWith(arg, loadOp); + } } } diff --git a/mlir/lib/Transforms/Utils/DialectConversion.cpp b/mlir/lib/Transforms/Utils/DialectConversion.cpp index c9f1596c07cbe..ccc5b7cb6f229 100644 --- a/mlir/lib/Transforms/Utils/DialectConversion.cpp +++ b/mlir/lib/Transforms/Utils/DialectConversion.cpp @@ -1205,17 +1205,14 @@ void BlockTypeConversionRewrite::rollback() { getNewBlock()->replaceAllUsesWith(getOrigBlock()); } -/// Replace all uses of `from` with `repl`. -static void -performReplaceValue(RewriterBase &rewriter, Value from, Value repl, - function_ref<bool(OpOperand &)> functor = nullptr) { +void ReplaceValueRewrite::commit(RewriterBase &rewriter) { + Value repl = rewriterImpl.findOrBuildReplacementValue(value, converter); + if (!repl) + return; + if (isa<BlockArgument>(repl)) { // `repl` is a block argument. Directly replace all uses. - if (functor) { - rewriter.replaceUsesWithIf(from, repl, functor); - } else { - rewriter.replaceAllUsesWith(from, repl); - } + rewriter.replaceAllUsesWith(value, repl); return; } @@ -1244,23 +1241,14 @@ performReplaceValue(RewriterBase &rewriter, Value from, Value repl, // `ConversionPatternRewriter` API with the normal `RewriterBase` API. 
Operation *replOp = repl.getDefiningOp(); Block *replBlock = replOp->getBlock(); - rewriter.replaceUsesWithIf(from, repl, [&](OpOperand &operand) { + rewriter.replaceUsesWithIf(value, repl, [&](OpOperand &operand) { Operation *user = operand.getOwner(); bool result = user->getBlock() != replBlock || replOp->isBeforeInBlock(user); - if (functor) - result &= functor(operand); return result; }); } -void ReplaceValueRewrite::commit(RewriterBase &rewriter) { - Value repl = rewriterImpl.findOrBuildReplacementValue(value, converter); - if (!repl) - return; - performReplaceValue(rewriter, value, repl); -} - void ReplaceValueRewrite::rollback() { rewriterImpl.mapping.erase({value}); #ifndef NDEBUG @@ -2000,8 +1988,11 @@ void ConversionPatternRewriterImpl::replaceValueUses( Value repl = repls.front(); if (!repl) return; - - performReplaceValue(r, from, repl, functor); + if (functor) { + r.replaceUsesWithIf(from, repl, functor); + } else { + r.replaceAllUsesWith(from, repl); + } return; } diff --git a/mlir/test/Transforms/test-convert-func-op.mlir b/mlir/test/Transforms/test-convert-func-op.mlir index 180f16a32991b..14c15ecbe77f0 100644 --- a/mlir/test/Transforms/test-convert-func-op.mlir +++ b/mlir/test/Transforms/test-convert-func-op.mlir @@ -1,4 +1,5 @@ -// RUN: mlir-opt %s -test-convert-func-op --split-input-file | FileCheck %s +// RUN: mlir-opt %s -test-convert-func-op="allow-pattern-rollback=1" --split-input-file | FileCheck %s +// RUN: mlir-opt %s -test-convert-func-op="allow-pattern-rollback=0" --split-input-file | FileCheck %s // CHECK-LABEL: llvm.func @add func.func @add(%arg0: i32, %arg1: i32) -> i32 attributes { llvm.emit_c_interface } { diff --git a/mlir/test/lib/Conversion/FuncToLLVM/TestConvertFuncOp.cpp b/mlir/test/lib/Conversion/FuncToLLVM/TestConvertFuncOp.cpp index 75168dde93130..897b11b65b6f2 100644 --- a/mlir/test/lib/Conversion/FuncToLLVM/TestConvertFuncOp.cpp +++ b/mlir/test/lib/Conversion/FuncToLLVM/TestConvertFuncOp.cpp @@ -68,6 +68,9 @@ struct 
TestConvertFuncOp : public PassWrapper<TestConvertFuncOp, OperationPass<ModuleOp>> { MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestConvertFuncOp) + TestConvertFuncOp() = default; + TestConvertFuncOp(const TestConvertFuncOp &other) : PassWrapper(other) {} + void getDependentDialects(DialectRegistry &registry) const final { registry.insert<LLVM::LLVMDialect>(); } @@ -92,10 +95,16 @@ struct TestConvertFuncOp patterns.add<ReturnOpConversion>(typeConverter); LLVMConversionTarget target(getContext()); + ConversionConfig config; + config.allowPatternRollback = allowPatternRollback; if (failed(applyPartialConversion(getOperation(), target, - std::move(patterns)))) + std::move(patterns), config))) signalPassFailure(); } + + Option<bool> allowPatternRollback{*this, "allow-pattern-rollback", + llvm::cl::desc("Allow pattern rollback"), + llvm::cl::init(true)}; }; } // namespace `````````` </details> https://github.com/llvm/llvm-project/pull/169609 _______________________________________________ llvm-branch-commits mailing list llvm-branch-commits@lists.llvm.org https://lists.llvm.org/cgi-bin/mailman/listinfo/llvm-branch-commits
