https://github.com/matthias-springer updated https://github.com/llvm/llvm-project/pull/81462
>From a79501ebced4a3410c3a28c6555973bb45156e76 Mon Sep 17 00:00:00 2001 From: Matthias Springer <spring...@google.com> Date: Wed, 14 Feb 2024 16:20:30 +0000 Subject: [PATCH] [mlir][Transforms][NFC] Simplify `ArgConverter` state * When converting a block signature, `ArgConverter` creates a new block with the new signature and moves all operations from the old block to the new block. The old block is temporarily moved into a region that is stored in `regionMapping`. The old block is not yet deleted, so that the conversion can be rolled back. `regionMapping` is not needed. Instead of moving the old block to a temporary region, it can just be unlinked. Block erasures are handled in the same way in the dialect conversion. * `regionToConverter` is a mapping from regions to type converters. That field is never accessed within `ArgConverter`. It should be stored in `ConversionPatternRewriterImpl` instead. --- .../Transforms/Utils/DialectConversion.cpp | 79 ++++++------------- 1 file changed, 22 insertions(+), 57 deletions(-) diff --git a/mlir/lib/Transforms/Utils/DialectConversion.cpp b/mlir/lib/Transforms/Utils/DialectConversion.cpp index 5206a65608ba14..67b076b295eae8 100644 --- a/mlir/lib/Transforms/Utils/DialectConversion.cpp +++ b/mlir/lib/Transforms/Utils/DialectConversion.cpp @@ -343,23 +343,6 @@ struct ArgConverter { const TypeConverter *converter; }; - /// Return if the signature of the given block has already been converted. - bool hasBeenConverted(Block *block) const { - return conversionInfo.count(block) || convertedBlocks.count(block); - } - - /// Set the type converter to use for the given region. - void setConverter(Region *region, const TypeConverter *typeConverter) { - assert(typeConverter && "expected valid type converter"); - regionToConverter[region] = typeConverter; - } - - /// Return the type converter to use for the given region, or null if there - /// isn't one. 
- const TypeConverter *getConverter(Region *region) { - return regionToConverter.lookup(region); - } - //===--------------------------------------------------------------------===// // Rewrite Application //===--------------------------------------------------------------------===// @@ -409,24 +392,10 @@ struct ArgConverter { ConversionValueMapping &mapping, SmallVectorImpl<BlockArgument> &argReplacements); - /// Insert a new conversion into the cache. - void insertConversion(Block *newBlock, ConvertedBlockInfo &&info); - /// A collection of blocks that have had their arguments converted. This is a /// map from the new replacement block, back to the original block. llvm::MapVector<Block *, ConvertedBlockInfo> conversionInfo; - /// The set of original blocks that were converted. - DenseSet<Block *> convertedBlocks; - - /// A mapping from valid regions, to those containing the original blocks of a - /// conversion. - DenseMap<Region *, std::unique_ptr<Region>> regionMapping; - - /// A mapping of regions to type converters that should be used when - /// converting the arguments of blocks within that region. - DenseMap<Region *, const TypeConverter *> regionToConverter; - /// The pattern rewriter to use when materializing conversions. PatternRewriter &rewriter; @@ -474,12 +443,12 @@ void ArgConverter::discardRewrites(Block *block) { block->getArgument(i).dropAllUses(); block->replaceAllUsesWith(origBlock); - // Move the operations back the original block and the delete the new block. + // Move the operations back the original block, move the original block back + // into its original location and the delete the new block. 
origBlock->getOperations().splice(origBlock->end(), block->getOperations()); - origBlock->moveBefore(block); + block->getParent()->getBlocks().insert(Region::iterator(block), origBlock); block->erase(); - convertedBlocks.erase(origBlock); conversionInfo.erase(it); } @@ -510,6 +479,9 @@ void ArgConverter::applyRewrites(ConversionValueMapping &mapping) { mapping.lookupOrDefault(castValue, origArg.getType())); } } + + delete origBlock; + blockInfo.origBlock = nullptr; } } @@ -572,9 +544,11 @@ FailureOr<Block *> ArgConverter::convertSignature( Block *block, const TypeConverter *converter, ConversionValueMapping &mapping, SmallVectorImpl<BlockArgument> &argReplacements) { - // Check if the block was already converted. If the block is detached, - // conservatively assume it is going to be deleted. - if (hasBeenConverted(block) || !block->getParent()) + // Check if the block was already converted. + // * If the block is mapped in `conversionInfo`, it is a converted block. + // * If the block is detached, conservatively assume that it is going to be + // deleted; it is likely the old block (before it was converted). + if (conversionInfo.count(block) || !block->getParent()) return block; // If a converter wasn't provided, and the block wasn't already converted, // there is nothing we can do. @@ -603,6 +577,9 @@ Block *ArgConverter::applySignatureConversion( // signature. Block *newBlock = block->splitBlock(block->begin()); block->replaceAllUsesWith(newBlock); + // Unlink the block, but do not erase it yet, so that the change can be rolled + // back. + block->getParent()->getBlocks().remove(block); // Map all new arguments to the location of the argument they originate from. SmallVector<Location> newLocs(convertedTypes.size(), @@ -679,24 +656,8 @@ Block *ArgConverter::applySignatureConversion( ConvertedArgInfo(inputMap->inputNo, inputMap->size, newArg); } - // Remove the original block from the region and return the new one. 
- insertConversion(newBlock, std::move(info)); - return newBlock; -} - -void ArgConverter::insertConversion(Block *newBlock, - ConvertedBlockInfo &&info) { - // Get a region to insert the old block. - Region *region = newBlock->getParent(); - std::unique_ptr<Region> &mappedRegion = regionMapping[region]; - if (!mappedRegion) - mappedRegion = std::make_unique<Region>(region->getParentOp()); - - // Move the original block to the mapped region and emplace the conversion. - mappedRegion->getBlocks().splice(mappedRegion->end(), region->getBlocks(), - info.origBlock->getIterator()); - convertedBlocks.insert(info.origBlock); conversionInfo.insert({newBlock, std::move(info)}); + return newBlock; } //===----------------------------------------------------------------------===// @@ -1196,6 +1157,10 @@ struct ConversionPatternRewriterImpl : public RewriterBase::Listener { /// active. const TypeConverter *currentTypeConverter = nullptr; + /// A mapping of regions to type converters that should be used when + /// converting the arguments of blocks within that region. + DenseMap<Region *, const TypeConverter *> regionToConverter; + /// This allows the user to collect the match failure message. 
function_ref<void(Diagnostic &)> notifyCallback; @@ -1473,7 +1438,7 @@ Block *ConversionPatternRewriterImpl::applySignatureConversion( FailureOr<Block *> ConversionPatternRewriterImpl::convertRegionTypes( Region *region, const TypeConverter &converter, TypeConverter::SignatureConversion *entryConversion) { - argConverter.setConverter(region, &converter); + regionToConverter[region] = &converter; if (region->empty()) return nullptr; @@ -1488,7 +1453,7 @@ FailureOr<Block *> ConversionPatternRewriterImpl::convertRegionTypes( LogicalResult ConversionPatternRewriterImpl::convertNonEntryRegionTypes( Region *region, const TypeConverter &converter, ArrayRef<TypeConverter::SignatureConversion> blockConversions) { - argConverter.setConverter(region, &converter); + regionToConverter[region] = &converter; if (region->empty()) return success(); @@ -2162,7 +2127,7 @@ LogicalResult OperationLegalizer::legalizePatternBlockRewrites( // If the region of the block has a type converter, try to convert the block // directly. - if (auto *converter = impl.argConverter.getConverter(block->getParent())) { + if (auto *converter = impl.regionToConverter.lookup(block->getParent())) { if (failed(impl.convertBlockSignature(block, converter))) { LLVM_DEBUG(logFailure(impl.logger, "failed to convert types of moved " "block")); _______________________________________________ llvm-branch-commits mailing list llvm-branch-commits@lists.llvm.org https://lists.llvm.org/cgi-bin/mailman/listinfo/llvm-branch-commits