https://github.com/krzysz00 updated https://github.com/llvm/llvm-project/pull/177016
>From 8c8627ccec95fbdd60252430f4bb438a22906529 Mon Sep 17 00:00:00 2001 From: Krzysztof Drewniak <[email protected]> Date: Tue, 20 Jan 2026 02:09:23 +0000 Subject: [PATCH] [mlir][MemRef] Make fold-memref-alias-ops use memref interfaces This replaces the large switch-cases and operation-specific patterns in FoldMemRefAliasOps with patterns that use the new IndexedAccessOpInterface and IndexedMemCopyOpInterface, which will allow us to remove the memref transforms' dependency on the NVGPU dialect. This also resolves some bugs and potential unsoundnesses: 1. We will no longer fold an expand_shape into vector.load or vector.transfer_read in cases where that would alter the strides between dimensions in multi-dimensional loads. For example, if we have a `vector.load %e[%i, %j, %k, %l] : memref<8x8x3x3xf32>, vector<2x3xf32>` where %e is `memref.expand_shape %m [[0], [1], [2, 3]] : memref<8x8x9xf32> into memref<8x8x3x3xf32>`, we will no longer fold in that shape, since that would change which values would be read (the previous patterns tried to account for this but failed). 2. Subviews that have non-unit strides in positions that aren't being meaningfully accessed (e.g. the outer dimensions of a 1-D vector load) are now folded. 3. While it is still not possible to fold a collapse_shape with a transfer_read or transfer_write if it would affect the transfer dimensions, the fold will now occur if it would not. 4. DMA operations (nvgpu async copy, memref.dma_start) now support expand_shape and collapse_shape folding in addition to subview. 5. Loading or storing a 1xN vector from a memref where the 1 and N are the result of expanding the same dimension will now fold into loading a vector of length N and a shape cast. 6. An issue where IR would be created before pattern failure has been resolved. Assisted-By: Claude Code (generating copies of similar patterns, test generation) --- .../MemRef/Transforms/FoldMemRefAliasOps.cpp | 829 +++++++++--------- .../Dialect/MemRef/fold-memref-alias-ops.mlir | 293 ++++++- 2 files changed, 697 insertions(+), 425 deletions(-) diff --git a/mlir/lib/Dialect/MemRef/Transforms/FoldMemRefAliasOps.cpp b/mlir/lib/Dialect/MemRef/Transforms/FoldMemRefAliasOps.cpp index 3cacb7e29263b..b1eb7958d2793 100644 --- a/mlir/lib/Dialect/MemRef/Transforms/FoldMemRefAliasOps.cpp +++ b/mlir/lib/Dialect/MemRef/Transforms/FoldMemRefAliasOps.cpp @@ -13,14 +13,14 @@ #include "mlir/Dialect/Affine/ViewLikeInterfaceUtils.h" #include "mlir/Dialect/Arith/Utils/Utils.h" -#include "mlir/Dialect/GPU/IR/GPUDialect.h" #include "mlir/Dialect/MemRef/IR/MemRef.h" +#include "mlir/Dialect/MemRef/IR/MemoryAccessOpInterfaces.h" #include "mlir/Dialect/MemRef/Transforms/Passes.h" #include "mlir/Dialect/MemRef/Transforms/Transforms.h" #include "mlir/Dialect/MemRef/Utils/MemRefUtils.h" -#include "mlir/Dialect/NVGPU/IR/NVGPUDialect.h" #include "mlir/Dialect/Vector/IR/VectorOps.h" #include "mlir/IR/AffineMap.h" +#include "mlir/IR/PatternMatch.h" #include "mlir/Transforms/GreedyPatternRewriteDriver.h" #include "llvm/ADT/STLExtras.h" #include "llvm/ADT/SmallBitVector.h" @@ -43,38 +43,23 @@ using namespace mlir; // Utility functions //===----------------------------------------------------------------------===// -/// Helpers to access the memref operand for each op. -template <typename LoadOrStoreOpTy> -static Value getMemRefOperand(LoadOrStoreOpTy op) { - return op.getMemref(); +/// Determine if the last N groups of `reassocs` are trivial - that is, +/// check if they all contain exactly one dimension to collapse/expand into. 
+static bool +hasTrivialReassociationSuffix(ArrayRef<ReassociationIndices> reassocs, + int64_t n) { + if (n <= 0) + return true; + return llvm::all_of( + reassocs.take_back(n), + [&](const ReassociationIndices &indices) { return indices.size() == 1; }); } -static Value getMemRefOperand(vector::TransferReadOp op) { - return op.getBase(); -} - -static Value getMemRefOperand(nvgpu::LdMatrixOp op) { - return op.getSrcMemref(); -} - -static Value getMemRefOperand(vector::LoadOp op) { return op.getBase(); } - -static Value getMemRefOperand(vector::StoreOp op) { return op.getBase(); } - -static Value getMemRefOperand(vector::MaskedLoadOp op) { return op.getBase(); } - -static Value getMemRefOperand(vector::MaskedStoreOp op) { return op.getBase(); } - -static Value getMemRefOperand(vector::TransferWriteOp op) { - return op.getBase(); -} - -static Value getMemRefOperand(gpu::SubgroupMmaLoadMatrixOp op) { - return op.getSrcMemref(); -} - -static Value getMemRefOperand(gpu::SubgroupMmaStoreMatrixOp op) { - return op.getDstMemref(); +static bool hasTrailingUnitStrides(memref::SubViewOp subview, int64_t n) { + if (n <= 0) + return true; + return llvm::all_of(subview.getStaticStrides().take_back(n), + [](int64_t s) { return s == 1; }); } //===----------------------------------------------------------------------===// @@ -82,63 +67,120 @@ static Value getMemRefOperand(gpu::SubgroupMmaStoreMatrixOp op) { //===----------------------------------------------------------------------===// namespace { -/// Merges subview operation with load/transferRead operation. -template <typename OpTy> -class LoadOpOfSubViewOpFolder final : public OpRewritePattern<OpTy> { -public: - using OpRewritePattern<OpTy>::OpRewritePattern; +/// Merges subview operations with load/store-like operations unless such a +/// merger would cause the strides between dimensions accessed by that operation +/// to change. +struct AccessOpOfSubViewOpFolder final + : OpInterfaceRewritePattern<memref::IndexedAccessOpInterface> { + using Base::Base; + + LogicalResult matchAndRewrite(memref::IndexedAccessOpInterface op, + PatternRewriter &rewriter) const override; }; - LogicalResult matchAndRewrite(OpTy loadOp, +/// Merges a memref.expand_shape operation with an operation that accesses a +/// memref by index unless that operation accesses more than one dimension of +/// memory and any dimension other than the outermost dimension accessed this +/// way would be merged. This prevents issues from arising with, say, a +/// vector.load of a 4x2 vector having the two trailing dimensions of the access +/// merged. +struct AccessOpOfExpandShapeOpFolder final + : OpInterfaceRewritePattern<memref::IndexedAccessOpInterface> { + using Base::Base; + + LogicalResult matchAndRewrite(memref::IndexedAccessOpInterface op, PatternRewriter &rewriter) const override; }; -/// Merges expand_shape operation with load/transferRead operation. -template <typename OpTy> -class LoadOpOfExpandShapeOpFolder final : public OpRewritePattern<OpTy> { -public: - using OpRewritePattern<OpTy>::OpRewritePattern; +/// Merges an operation that accesses a memref by index with a +/// memref.collapse_shape, unless this would break apart a dimension other than +/// the outermost one that an operation accesses. 
This prevents, for example, +/// transforming a load of a 3x8 vector from a 6x8 memref into a load +/// from a 6x4x2 memref (as this would require special handling and could lead +/// to invalid IR if that higher-dimensional memref comes from a subview) but +/// does permit turning a load of a length-8 vector from a 6x8 memref into a +/// load from a 6x4x2 one. +struct AccessOpOfCollapseShapeOpFolder final + : OpInterfaceRewritePattern<memref::IndexedAccessOpInterface> { + using Base::Base; + + LogicalResult matchAndRewrite(memref::IndexedAccessOpInterface op, + PatternRewriter &rewriter) const override; +}; + +/// Merges memref.subview operations present on the source or destination +/// operands of indexed memory copy operations (DMA operations) into those +/// operations. This is performed unconditionally, since folding in a subview +/// cannot change the starting position of the copy, which is what the +/// memref/index pair represents in DMA operations. +struct IndexedMemCopyOpOfSubViewOpFolder final + : OpInterfaceRewritePattern<memref::IndexedMemCopyOpInterface> { + using Base::Base; - LogicalResult matchAndRewrite(OpTy loadOp, + LogicalResult matchAndRewrite(memref::IndexedMemCopyOpInterface op, PatternRewriter &rewriter) const override; }; -/// Merges collapse_shape operation with load/transferRead operation. -template <typename OpTy> -class LoadOpOfCollapseShapeOpFolder final : public OpRewritePattern<OpTy> { -public: - using OpRewritePattern<OpTy>::OpRewritePattern; +/// Merges memref.expand_shape operations that are present on the source or +/// destination of an indexed memory copy/DMA into the memref/index arguments of +/// that DMA. As with subviews, this can be done unconditionally. +struct IndexedMemCopyOpOfExpandShapeOpFolder final + : OpInterfaceRewritePattern<memref::IndexedMemCopyOpInterface> { + using Base::Base; - LogicalResult matchAndRewrite(OpTy loadOp, + LogicalResult matchAndRewrite(memref::IndexedMemCopyOpInterface op, PatternRewriter &rewriter) const override; }; -/// Merges subview operation with store/transferWriteOp operation. -template <typename OpTy> -class StoreOpOfSubViewOpFolder final : public OpRewritePattern<OpTy> { -public: - using OpRewritePattern<OpTy>::OpRewritePattern; +/// Merges memref.collapse_shape operations that are present on the source or +/// destination of an indexed memory copy/DMA into the memref/index arguments of +/// that DMA. As with subviews, this can be done unconditionally. +struct IndexedMemCopyOpOfCollapseShapeOpFolder final + : OpInterfaceRewritePattern<memref::IndexedMemCopyOpInterface> { + using Base::Base; - LogicalResult matchAndRewrite(OpTy storeOp, + LogicalResult matchAndRewrite(memref::IndexedMemCopyOpInterface op, PatternRewriter &rewriter) const override; }; -/// Merges expand_shape operation with store/transferWriteOp operation. -template <typename OpTy> -class StoreOpOfExpandShapeOpFolder final : public OpRewritePattern<OpTy> { -public: - using OpRewritePattern<OpTy>::OpRewritePattern; +/// Merges memref.subview ops on the base argument to vector transfer operations +/// into the base and indices of that transfer if: +/// - The subview has unit strides on transfer dimensions +/// - All the transfer dimensions are in-bounds +/// This will correctly update the transfer's permutation map to account for +/// dropped dimensions in rank-reducing subviews. 
+struct TransferOpOfSubViewOpFolder final + : OpInterfaceRewritePattern<VectorTransferOpInterface> { + using Base::Base; + + LogicalResult matchAndRewrite(VectorTransferOpInterface op, + PatternRewriter &rewriter) const override; +}; - LogicalResult matchAndRewrite(OpTy storeOp, +/// Merges memref.expand_shape ops that create the base of a vector transfer +/// operation into the base and indices of that transfer. Does not act when the +/// permutation map is not a minor identity, a dimension is potentially out of +/// bounds, or if it would merge two dimensions that are both transfer +/// dimensions. +/// TODO: become more sophisticated about length-1 dimensions that are the +/// result of an expansion becoming broadcasts. +struct TransferOpOfExpandShapeOpFolder final + : OpInterfaceRewritePattern<VectorTransferOpInterface> { + using Base::Base; + + LogicalResult matchAndRewrite(VectorTransferOpInterface op, PatternRewriter &rewriter) const override; }; -/// Merges collapse_shape operation with store/transferWriteOp operation. -template <typename OpTy> -class StoreOpOfCollapseShapeOpFolder final : public OpRewritePattern<OpTy> { -public: - using OpRewritePattern<OpTy>::OpRewritePattern; +/// Merges memref.collapse_shape ops that create the base of a vector transfer +/// operation into the base and indices of that transfer. Does not act when the +/// permutation map is not trivial, a dimension could be performing out of +/// bounds reads, or if it would break apart a transfer dimension. +struct TransferOpOfCollapseShapeOpFolder final + : OpInterfaceRewritePattern<VectorTransferOpInterface> { + using Base::Base; - LogicalResult matchAndRewrite(OpTy storeOp, + LogicalResult matchAndRewrite(VectorTransferOpInterface op, PatternRewriter &rewriter) const override; }; @@ -184,399 +226,338 @@ class SubViewOfSubViewFolder : public OpRewritePattern<memref::SubViewOp> { return success(); } }; - -/// Folds nvgpu.device_async_copy subviews into the copy itself. This pattern -/// is folds subview on src and dst memref of the copy. -class NVGPUAsyncCopyOpSubViewOpFolder final - : public OpRewritePattern<nvgpu::DeviceAsyncCopyOp> { -public: - using OpRewritePattern<nvgpu::DeviceAsyncCopyOp>::OpRewritePattern; - - LogicalResult matchAndRewrite(nvgpu::DeviceAsyncCopyOp copyOp, - PatternRewriter &rewriter) const override; -}; } // namespace -template <typename XferOp> -static LogicalResult -preconditionsFoldSubViewOpImpl(RewriterBase &rewriter, XferOp xferOp, - memref::SubViewOp subviewOp) { - static_assert( - !llvm::is_one_of<vector::TransferReadOp, vector::TransferWriteOp>::value, - "must be a vector transfer op"); - if (xferOp.hasOutOfBoundsDim()) - return rewriter.notifyMatchFailure(xferOp, "out of bounds transfer dim"); - if (!subviewOp.hasUnitStride()) { +LogicalResult +AccessOpOfSubViewOpFolder::matchAndRewrite(memref::IndexedAccessOpInterface op, + PatternRewriter &rewriter) const { + auto subview = op.getMemref().getDefiningOp<memref::SubViewOp>(); + if (!subview) + return rewriter.notifyMatchFailure(op, "not accessing a subview"); + + SmallVector<int64_t> accessedShape = op.getAccessedShape(); + // Note the subtle difference between accessedShape = {1} and accessedShape = + // {} here. The former prevents us from folding in a subview that doesn't + // have a unit stride on the final dimension, while the latter does not (since + // it indicates scalar access). 
+ int64_t accessedDims = accessedShape.size(); + if (!hasTrailingUnitStrides(subview, accessedDims)) return rewriter.notifyMatchFailure( - xferOp, "non-1 stride subview, need to track strides in folded memref"); + op, "non-unit stride on accessed dimensions"); + + llvm::SmallBitVector droppedDims = subview.getDroppedDims(); + int64_t sourceRank = subview.getSourceType().getRank(); + + // Ignore outermost access dimension - we only care about dropped dimensions + // between the accessed op's results, as those could break the accessing op's + // sematics. + int64_t secondAccessedDim = sourceRank - (accessedDims - 1); + if (secondAccessedDim < sourceRank) { + for (int64_t d : llvm::seq(secondAccessedDim, sourceRank)) { + if (droppedDims.test(d)) + return rewriter.notifyMatchFailure( + op, "reintroducing dropped dimension " + Twine(d) + + " would break access op semantics"); + } } - return success(); -} - -static LogicalResult preconditionsFoldSubViewOp(RewriterBase &rewriter, - Operation *op, - memref::SubViewOp subviewOp) { - return success(); -} - -static LogicalResult preconditionsFoldSubViewOp(RewriterBase &rewriter, - vector::TransferReadOp readOp, - memref::SubViewOp subviewOp) { - return preconditionsFoldSubViewOpImpl(rewriter, readOp, subviewOp); -} - -static LogicalResult preconditionsFoldSubViewOp(RewriterBase &rewriter, - vector::TransferWriteOp writeOp, - memref::SubViewOp subviewOp) { - return preconditionsFoldSubViewOpImpl(rewriter, writeOp, subviewOp); -} - -template <typename OpTy> -LogicalResult LoadOpOfSubViewOpFolder<OpTy>::matchAndRewrite( - OpTy loadOp, PatternRewriter &rewriter) const { - auto subViewOp = - getMemRefOperand(loadOp).template getDefiningOp<memref::SubViewOp>(); - - if (!subViewOp) - return rewriter.notifyMatchFailure(loadOp, "not a subview producer"); - - LogicalResult preconditionResult = - preconditionsFoldSubViewOp(rewriter, loadOp, subViewOp); - if (failed(preconditionResult)) - return preconditionResult; SmallVector<Value> sourceIndices; affine::resolveIndicesIntoOpWithOffsetsAndStrides( - rewriter, loadOp.getLoc(), subViewOp.getMixedOffsets(), - subViewOp.getMixedStrides(), subViewOp.getDroppedDims(), - loadOp.getIndices(), sourceIndices); - - llvm::TypeSwitch<Operation *, void>(loadOp) - .Case([&](memref::LoadOp op) { - rewriter.replaceOpWithNewOp<memref::LoadOp>( - loadOp, subViewOp.getSource(), sourceIndices, op.getNontemporal()); - }) - .Case([&](vector::LoadOp op) { - rewriter.replaceOpWithNewOp<vector::LoadOp>( - op, op.getType(), subViewOp.getSource(), sourceIndices); - }) - .Case([&](vector::MaskedLoadOp op) { - rewriter.replaceOpWithNewOp<vector::MaskedLoadOp>( - op, op.getType(), subViewOp.getSource(), sourceIndices, - op.getMask(), op.getPassThru()); - }) - .Case([&](vector::TransferReadOp op) { - rewriter.replaceOpWithNewOp<vector::TransferReadOp>( - op, op.getVectorType(), subViewOp.getSource(), sourceIndices, - AffineMapAttr::get(expandDimsToRank( - op.getPermutationMap(), subViewOp.getSourceType().getRank(), - subViewOp.getDroppedDims())), - op.getPadding(), op.getMask(), op.getInBoundsAttr()); - }) - .Case([&](gpu::SubgroupMmaLoadMatrixOp op) { - rewriter.replaceOpWithNewOp<gpu::SubgroupMmaLoadMatrixOp>( - op, op.getType(), subViewOp.getSource(), sourceIndices, - op.getLeadDimension(), op.getTransposeAttr()); - }) - .Case([&](nvgpu::LdMatrixOp op) { - rewriter.replaceOpWithNewOp<nvgpu::LdMatrixOp>( - op, op.getType(), subViewOp.getSource(), sourceIndices, - op.getTranspose(), op.getNumTiles()); - }) - .DefaultUnreachable("unexpected 
operation"); + rewriter, op.getLoc(), subview.getMixedOffsets(), + subview.getMixedStrides(), droppedDims, op.getIndices(), sourceIndices); + + std::optional<SmallVector<Value>> newValues = + op.updateMemrefAndIndices(rewriter, subview.getSource(), sourceIndices); + if (newValues) + rewriter.replaceOp(op, *newValues); return success(); } -template <typename OpTy> -LogicalResult LoadOpOfExpandShapeOpFolder<OpTy>::matchAndRewrite( - OpTy loadOp, PatternRewriter &rewriter) const { - auto expandShapeOp = - getMemRefOperand(loadOp).template getDefiningOp<memref::ExpandShapeOp>(); - - if (!expandShapeOp) - return failure(); +LogicalResult AccessOpOfExpandShapeOpFolder::matchAndRewrite( + memref::IndexedAccessOpInterface op, PatternRewriter &rewriter) const { + auto expand = op.getMemref().getDefiningOp<memref::ExpandShapeOp>(); + if (!expand) + return rewriter.notifyMatchFailure(op, "not accessing an expand_shape"); + + SmallVector<int64_t> rawAccessedShape = op.getAccessedShape(); + ArrayRef<int64_t> accessedShape = rawAccessedShape; + // Cut off the leading dimension, since we don't care about monifying its + // strides. + if (!accessedShape.empty()) + accessedShape = accessedShape.drop_front(); + + auto reassocs = expand.getReassociationIndices(); + if (!hasTrivialReassociationSuffix(reassocs, accessedShape.size())) + return rewriter.notifyMatchFailure( + op, + "expand_shape folding would merge semanvtically important dimensions"); SmallVector<Value> sourceIndices; - // memref.load guarantees that indexes start inbounds while the vector - // operations don't. This impacts if our linearization is `disjoint` - resolveSourceIndicesExpandShape(loadOp.getLoc(), rewriter, expandShapeOp, - loadOp.getIndices(), sourceIndices, - isa<memref::LoadOp>(loadOp.getOperation())); - - return llvm::TypeSwitch<Operation *, LogicalResult>(loadOp) - .Case([&](memref::LoadOp op) { - rewriter.replaceOpWithNewOp<memref::LoadOp>( - loadOp, expandShapeOp.getViewSource(), sourceIndices, - op.getNontemporal()); - return success(); - }) - .Case([&](vector::LoadOp op) { - rewriter.replaceOpWithNewOp<vector::LoadOp>( - op, op.getType(), expandShapeOp.getViewSource(), sourceIndices, - op.getNontemporal()); - return success(); - }) - .Case([&](vector::MaskedLoadOp op) { - rewriter.replaceOpWithNewOp<vector::MaskedLoadOp>( - op, op.getType(), expandShapeOp.getViewSource(), sourceIndices, - op.getMask(), op.getPassThru()); - return success(); - }) - .Case([&](vector::TransferReadOp op) { - // We only support minor identity maps in the permutation attribute. - if (!op.getPermutationMap().isMinorIdentity()) - return failure(); - - // We only support the case where the source of the expand shape has - // rank greater than or equal to the vector rank. - const int64_t sourceRank = sourceIndices.size(); - const int64_t vectorRank = op.getVectorType().getRank(); - if (sourceRank < vectorRank) - return failure(); - - // We need to construct a new minor identity map since we will have lost - // some dimensions in folding away the expand shape. 
- auto minorIdMap = AffineMap::getMinorIdentityMap(sourceRank, vectorRank, - op.getContext()); - - rewriter.replaceOpWithNewOp<vector::TransferReadOp>( - op, op.getVectorType(), expandShapeOp.getViewSource(), - sourceIndices, minorIdMap, op.getPadding(), op.getMask(), - op.getInBounds()); - return success(); - }) - .DefaultUnreachable("unexpected operation"); + memref::resolveSourceIndicesExpandShape(op.getLoc(), rewriter, expand, + op.getIndices(), sourceIndices, + op.hasInboundsIndices()); + + std::optional<SmallVector<Value>> newValues = op.updateMemrefAndIndices( + rewriter, expand.getViewSource(), sourceIndices); + if (newValues) + rewriter.replaceOp(op, *newValues); + return success(); } -template <typename OpTy> -LogicalResult LoadOpOfCollapseShapeOpFolder<OpTy>::matchAndRewrite( - OpTy loadOp, PatternRewriter &rewriter) const { - auto collapseShapeOp = getMemRefOperand(loadOp) - .template getDefiningOp<memref::CollapseShapeOp>(); - - if (!collapseShapeOp) - return failure(); +LogicalResult AccessOpOfCollapseShapeOpFolder::matchAndRewrite( + memref::IndexedAccessOpInterface op, PatternRewriter &rewriter) const { + auto collapse = op.getMemref().getDefiningOp<memref::CollapseShapeOp>(); + if (!collapse) + return rewriter.notifyMatchFailure(op, "not accessing a collapse_shape"); + + SmallVector<int64_t> rawAccessedShape = op.getAccessedShape(); + ArrayRef<int64_t> accessedShape = rawAccessedShape; + // Cut off the leading dimension, since we don't care about its strides being + // modified and we know that the dimensions within its reassociation group, if + // it's non-trivial, must be contiguous. + if (!accessedShape.empty()) + accessedShape = accessedShape.drop_front(); + + auto reassocs = collapse.getReassociationIndices(); + if (!hasTrivialReassociationSuffix(reassocs, accessedShape.size())) + return rewriter.notifyMatchFailure(op, + "collapse_shape folding would merge " + "semanvtically important dimensions"); SmallVector<Value> sourceIndices; - resolveSourceIndicesCollapseShape(loadOp.getLoc(), rewriter, collapseShapeOp, - loadOp.getIndices(), sourceIndices); - llvm::TypeSwitch<Operation *, void>(loadOp) - .Case([&](memref::LoadOp op) { - rewriter.replaceOpWithNewOp<memref::LoadOp>( - loadOp, collapseShapeOp.getViewSource(), sourceIndices, - op.getNontemporal()); - }) - .Case([&](vector::LoadOp op) { - rewriter.replaceOpWithNewOp<vector::LoadOp>( - op, op.getType(), collapseShapeOp.getViewSource(), sourceIndices, - op.getNontemporal()); - }) - .Case([&](vector::MaskedLoadOp op) { - rewriter.replaceOpWithNewOp<vector::MaskedLoadOp>( - op, op.getType(), collapseShapeOp.getViewSource(), sourceIndices, - op.getMask(), op.getPassThru()); - }) - .DefaultUnreachable("unexpected operation"); + memref::resolveSourceIndicesCollapseShape(op.getLoc(), rewriter, collapse, + op.getIndices(), sourceIndices); + + std::optional<SmallVector<Value>> newValues = op.updateMemrefAndIndices( + rewriter, collapse.getViewSource(), sourceIndices); + if (newValues) + rewriter.replaceOp(op, *newValues); return success(); } -template <typename OpTy> -LogicalResult StoreOpOfSubViewOpFolder<OpTy>::matchAndRewrite( - OpTy storeOp, PatternRewriter &rewriter) const { - auto subViewOp = - getMemRefOperand(storeOp).template getDefiningOp<memref::SubViewOp>(); - - if (!subViewOp) - return rewriter.notifyMatchFailure(storeOp, "not a subview producer"); - - LogicalResult preconditionResult = - preconditionsFoldSubViewOp(rewriter, storeOp, subViewOp); - if (failed(preconditionResult)) - return preconditionResult; - - 
SmallVector<Value> sourceIndices; - affine::resolveIndicesIntoOpWithOffsetsAndStrides( - rewriter, storeOp.getLoc(), subViewOp.getMixedOffsets(), - subViewOp.getMixedStrides(), subViewOp.getDroppedDims(), - storeOp.getIndices(), sourceIndices); - - llvm::TypeSwitch<Operation *, void>(storeOp) - .Case([&](memref::StoreOp op) { - rewriter.replaceOpWithNewOp<memref::StoreOp>( - op, op.getValue(), subViewOp.getSource(), sourceIndices, - op.getNontemporal()); - }) - .Case([&](vector::TransferWriteOp op) { - rewriter.replaceOpWithNewOp<vector::TransferWriteOp>( - op, op.getValue(), subViewOp.getSource(), sourceIndices, - AffineMapAttr::get(expandDimsToRank( - op.getPermutationMap(), subViewOp.getSourceType().getRank(), - subViewOp.getDroppedDims())), - op.getMask(), op.getInBoundsAttr()); - }) - .Case([&](vector::StoreOp op) { - rewriter.replaceOpWithNewOp<vector::StoreOp>( - op, op.getValueToStore(), subViewOp.getSource(), sourceIndices); - }) - .Case([&](vector::MaskedStoreOp op) { - rewriter.replaceOpWithNewOp<vector::MaskedStoreOp>( - op, subViewOp.getSource(), sourceIndices, op.getMask(), - op.getValueToStore()); - }) - .Case([&](gpu::SubgroupMmaStoreMatrixOp op) { - rewriter.replaceOpWithNewOp<gpu::SubgroupMmaStoreMatrixOp>( - op, op.getSrc(), subViewOp.getSource(), sourceIndices, - op.getLeadDimension(), op.getTransposeAttr()); - }) - .DefaultUnreachable("unexpected operation"); +LogicalResult IndexedMemCopyOpOfSubViewOpFolder::matchAndRewrite( + memref::IndexedMemCopyOpInterface op, PatternRewriter &rewriter) const { + auto srcSubview = op.getSrc().getDefiningOp<memref::SubViewOp>(); + auto dstSubview = op.getDst().getDefiningOp<memref::SubViewOp>(); + if (!srcSubview && !dstSubview) + return rewriter.notifyMatchFailure( + op, "no subviews found on indexed copy inputs"); + + Value newSrc = op.getSrc(); + SmallVector<Value> newSrcIndices = llvm::to_vector(op.getSrcIndices()); + Value newDst = op.getDst(); + SmallVector<Value> newDstIndices = llvm::to_vector(op.getDstIndices()); + if (srcSubview) { + newSrc = srcSubview.getSource(); + newSrcIndices.clear(); + affine::resolveIndicesIntoOpWithOffsetsAndStrides( + rewriter, op.getLoc(), srcSubview.getMixedOffsets(), + srcSubview.getMixedStrides(), srcSubview.getDroppedDims(), + op.getSrcIndices(), newSrcIndices); + } + if (dstSubview) { + newDst = dstSubview.getSource(); + newDstIndices.clear(); + affine::resolveIndicesIntoOpWithOffsetsAndStrides( + rewriter, op.getLoc(), dstSubview.getMixedOffsets(), + dstSubview.getMixedStrides(), dstSubview.getDroppedDims(), + op.getDstIndices(), newDstIndices); + } + op.setMemrefsAndIndices(rewriter, newSrc, newSrcIndices, newDst, + newDstIndices); return success(); } -template <typename OpTy> -LogicalResult StoreOpOfExpandShapeOpFolder<OpTy>::matchAndRewrite( - OpTy storeOp, PatternRewriter &rewriter) const { - auto expandShapeOp = - getMemRefOperand(storeOp).template getDefiningOp<memref::ExpandShapeOp>(); - - if (!expandShapeOp) - return failure(); - - SmallVector<Value> sourceIndices; - // memref.store guarantees that indexes start inbounds while the vector - // operations don't. 
This impacts if our linearization is `disjoint` - resolveSourceIndicesExpandShape(storeOp.getLoc(), rewriter, expandShapeOp, - storeOp.getIndices(), sourceIndices, - isa<memref::StoreOp>(storeOp.getOperation())); - llvm::TypeSwitch<Operation *, void>(storeOp) - .Case([&](memref::StoreOp op) { - rewriter.replaceOpWithNewOp<memref::StoreOp>( - storeOp, op.getValueToStore(), expandShapeOp.getViewSource(), - sourceIndices, op.getNontemporal()); - }) - .Case([&](vector::StoreOp op) { - rewriter.replaceOpWithNewOp<vector::StoreOp>( - op, op.getValueToStore(), expandShapeOp.getViewSource(), - sourceIndices, op.getNontemporal()); - }) - .Case([&](vector::MaskedStoreOp op) { - rewriter.replaceOpWithNewOp<vector::MaskedStoreOp>( - op, expandShapeOp.getViewSource(), sourceIndices, op.getMask(), - op.getValueToStore()); - }) - .DefaultUnreachable("unexpected operation"); +LogicalResult IndexedMemCopyOpOfExpandShapeOpFolder::matchAndRewrite( + memref::IndexedMemCopyOpInterface op, PatternRewriter &rewriter) const { + auto srcExpand = op.getSrc().getDefiningOp<memref::ExpandShapeOp>(); + auto dstExpand = op.getDst().getDefiningOp<memref::ExpandShapeOp>(); + if (!srcExpand && !dstExpand) + return rewriter.notifyMatchFailure( + op, "no expand_shapes found on indexed copy inputs"); + + Value newSrc = op.getSrc(); + SmallVector<Value> newSrcIndices = llvm::to_vector(op.getSrcIndices()); + Value newDst = op.getDst(); + SmallVector<Value> newDstIndices = llvm::to_vector(op.getDstIndices()); + if (srcExpand) { + newSrc = srcExpand.getViewSource(); + newSrcIndices.clear(); + memref::resolveSourceIndicesExpandShape(op.getLoc(), rewriter, srcExpand, + op.getSrcIndices(), newSrcIndices, + /*startsInbounds=*/true); + } + if (dstExpand) { + newDst = dstExpand.getViewSource(); + newDstIndices.clear(); + memref::resolveSourceIndicesExpandShape(op.getLoc(), rewriter, dstExpand, + op.getDstIndices(), newDstIndices, + /*startsInbounds=*/true); + } + op.setMemrefsAndIndices(rewriter, newSrc, newSrcIndices, newDst, + newDstIndices); return success(); } -template <typename OpTy> -LogicalResult StoreOpOfCollapseShapeOpFolder<OpTy>::matchAndRewrite( - OpTy storeOp, PatternRewriter &rewriter) const { - auto collapseShapeOp = getMemRefOperand(storeOp) - .template getDefiningOp<memref::CollapseShapeOp>(); - - if (!collapseShapeOp) - return failure(); - - SmallVector<Value> sourceIndices; - resolveSourceIndicesCollapseShape(storeOp.getLoc(), rewriter, collapseShapeOp, - storeOp.getIndices(), sourceIndices); - llvm::TypeSwitch<Operation *, void>(storeOp) - .Case([&](memref::StoreOp op) { - rewriter.replaceOpWithNewOp<memref::StoreOp>( - storeOp, op.getValueToStore(), collapseShapeOp.getViewSource(), - sourceIndices, op.getNontemporal()); - }) - .Case([&](vector::StoreOp op) { - rewriter.replaceOpWithNewOp<vector::StoreOp>( - op, op.getValueToStore(), collapseShapeOp.getViewSource(), - sourceIndices, op.getNontemporal()); - }) - .Case([&](vector::MaskedStoreOp op) { - rewriter.replaceOpWithNewOp<vector::MaskedStoreOp>( - op, collapseShapeOp.getViewSource(), sourceIndices, op.getMask(), - op.getValueToStore()); - }) - .DefaultUnreachable("unexpected operation"); +LogicalResult IndexedMemCopyOpOfCollapseShapeOpFolder::matchAndRewrite( + memref::IndexedMemCopyOpInterface op, PatternRewriter &rewriter) const { + auto srcCollapse = op.getSrc().getDefiningOp<memref::CollapseShapeOp>(); + auto dstCollapse = op.getDst().getDefiningOp<memref::CollapseShapeOp>(); + if (!srcCollapse && !dstCollapse) + return rewriter.notifyMatchFailure( + op, 
"no collapse_shapes found on indexed copy inputs"); + + Value newSrc = op.getSrc(); + SmallVector<Value> newSrcIndices = llvm::to_vector(op.getSrcIndices()); + Value newDst = op.getDst(); + SmallVector<Value> newDstIndices = llvm::to_vector(op.getDstIndices()); + if (srcCollapse) { + newSrc = srcCollapse.getViewSource(); + newSrcIndices.clear(); + memref::resolveSourceIndicesCollapseShape( + op.getLoc(), rewriter, srcCollapse, op.getSrcIndices(), newSrcIndices); + } + if (dstCollapse) { + newDst = dstCollapse.getViewSource(); + newDstIndices.clear(); + memref::resolveSourceIndicesCollapseShape( + op.getLoc(), rewriter, dstCollapse, op.getDstIndices(), newDstIndices); + } + op.setMemrefsAndIndices(rewriter, newSrc, newSrcIndices, newDst, + newDstIndices); return success(); } -LogicalResult NVGPUAsyncCopyOpSubViewOpFolder::matchAndRewrite( - nvgpu::DeviceAsyncCopyOp copyOp, PatternRewriter &rewriter) const { - - LLVM_DEBUG(DBGS() << "copyOp : " << copyOp << "\n"); - - auto srcSubViewOp = - copyOp.getSrc().template getDefiningOp<memref::SubViewOp>(); - auto dstSubViewOp = - copyOp.getDst().template getDefiningOp<memref::SubViewOp>(); - - if (!(srcSubViewOp || dstSubViewOp)) - return rewriter.notifyMatchFailure(copyOp, "does not use subview ops for " - "source or destination"); - - // If the source is a subview, we need to resolve the indices. - SmallVector<Value> foldedSrcIndices(copyOp.getSrcIndices().begin(), - copyOp.getSrcIndices().end()); +LogicalResult +TransferOpOfSubViewOpFolder::matchAndRewrite(VectorTransferOpInterface op, + PatternRewriter &rewriter) const { + auto subview = op.getBase().getDefiningOp<memref::SubViewOp>(); + if (!subview) + return rewriter.notifyMatchFailure(op, "not accessing a subview"); + + AffineMap perm = op.getPermutationMap(); + // Note: no identity permutation check here, since subview foldin can handle + // complex permutations because it doesn't merge or split any individual + // dimension. + if (op.hasOutOfBoundsDim()) + return rewriter.notifyMatchFailure(op, "out of bounds dimension"); + VectorType vecTy = op.getVectorType(); + // Because we know the permutation map is a minor identity, we know that the + // last N dimensions must have unit stride, where N is the vector rank. 
+ + if (!hasTrailingUnitStrides(subview, vecTy.getRank())) + return rewriter.notifyMatchFailure(subview, "non-unit stride within last " + + Twine(vecTy.getRank()) + + " dimensions"); + + AffineMap newPerm = expandDimsToRank(perm, subview.getSourceType().getRank(), + subview.getDroppedDims()); + + if (failed(op.mayUpdateStartingPosition(subview.getSourceType(), newPerm))) + return rewriter.notifyMatchFailure(subview, + "failed op-specific preconditions"); + + SmallVector<Value> newIndices; + affine::resolveIndicesIntoOpWithOffsetsAndStrides( + rewriter, op.getLoc(), subview.getMixedOffsets(), + subview.getMixedStrides(), subview.getDroppedDims(), op.getIndices(), + newIndices); + op.updateStartingPosition(rewriter, subview.getSource(), newIndices, + AffineMapAttr::get(newPerm)); + return success(); +} - if (srcSubViewOp) { - LLVM_DEBUG(DBGS() << "srcSubViewOp : " << srcSubViewOp << "\n"); - affine::resolveIndicesIntoOpWithOffsetsAndStrides( - rewriter, copyOp.getLoc(), srcSubViewOp.getMixedOffsets(), - srcSubViewOp.getMixedStrides(), srcSubViewOp.getDroppedDims(), - copyOp.getSrcIndices(), foldedSrcIndices); - } +LogicalResult TransferOpOfExpandShapeOpFolder::matchAndRewrite( + VectorTransferOpInterface op, PatternRewriter &rewriter) const { + auto expand = op.getBase().getDefiningOp<memref::ExpandShapeOp>(); + if (!expand) + return rewriter.notifyMatchFailure(op, "not accessing an expand_shape"); + + if (!op.getPermutationMap().isMinorIdentity()) + return rewriter.notifyMatchFailure(op, + "non-minor identity permutation map"); + if (op.hasOutOfBoundsDim()) + return rewriter.notifyMatchFailure(op, "out of bounds dimension"); + + int64_t srcRank = expand.getSrc().getType().getRank(); + int64_t vecRank = op.getVectorType().getRank(); + if (srcRank < vecRank) + return rewriter.notifyMatchFailure(op, + "source rank is less than vector rank"); + + SmallVector<ReassociationIndices> reassocs = expand.getReassociationIndices(); + if (!hasTrivialReassociationSuffix(reassocs, vecRank - 1)) + return rewriter.notifyMatchFailure( + op, "expand_shape folding would merge two transfer dimensions"); + + AffineMap newPerm = + AffineMap::getMinorIdentityMap(srcRank, vecRank, op.getContext()); + if (failed(op.mayUpdateStartingPosition(expand.getSrc().getType(), newPerm))) + return rewriter.notifyMatchFailure(op, "failed op-specific preconditions"); + + SmallVector<Value> newIndices; + // We can use a disjoint linearization if we aren't masking, because then all + // indicators show that the start position will be in bounds. + memref::resolveSourceIndicesExpandShape(op.getLoc(), rewriter, expand, + op.getIndices(), newIndices, + /*startsInbounds=*/!op.getMask()); + + op.updateStartingPosition(rewriter, expand.getViewSource(), newIndices, + AffineMapAttr::get(newPerm)); + return success(); +} - // If the destination is a subview, we need to resolve the indices. 
- SmallVector<Value> foldedDstIndices(copyOp.getDstIndices().begin(), - copyOp.getDstIndices().end()); +LogicalResult TransferOpOfCollapseShapeOpFolder::matchAndRewrite( + VectorTransferOpInterface op, PatternRewriter &rewriter) const { + auto collapse = op.getBase().getDefiningOp<memref::CollapseShapeOp>(); + if (!collapse) + return rewriter.notifyMatchFailure(op, "not accessing a collapse_shape"); + + if (!op.getPermutationMap().isMinorIdentity()) + return rewriter.notifyMatchFailure(op, + "non-minor identity permutation map"); + + if (op.hasOutOfBoundsDim()) + return rewriter.notifyMatchFailure(op, "out of bounds dimension"); + + int64_t srcRank = collapse.getSrc().getType().getRank(); + int64_t vecRank = op.getVectorType().getRank(); + if (srcRank < vecRank) + return rewriter.notifyMatchFailure(op, + "source rank is less than vector rank"); + + // Note: no - 1 on the rank here. While we could treat the collapse of [1, 1, + // N] into N as a specila case, that is left as future work for those who need + // such a pattern. + SmallVector<ReassociationIndices> reassocs = + collapse.getReassociationIndices(); + if (!hasTrivialReassociationSuffix(reassocs, vecRank)) + return rewriter.notifyMatchFailure( + op, "collapse_shape folding would split a transfer dimension"); - if (dstSubViewOp) { - LLVM_DEBUG(DBGS() << "dstSubViewOp : " << dstSubViewOp << "\n"); - affine::resolveIndicesIntoOpWithOffsetsAndStrides( - rewriter, copyOp.getLoc(), dstSubViewOp.getMixedOffsets(), - dstSubViewOp.getMixedStrides(), dstSubViewOp.getDroppedDims(), - copyOp.getDstIndices(), foldedDstIndices); - } + AffineMap newPerm = + AffineMap::getMinorIdentityMap(srcRank, vecRank, op.getContext()); + if (failed( + op.mayUpdateStartingPosition(collapse.getSrc().getType(), newPerm))) + return rewriter.notifyMatchFailure(op, "failed op-specific preconditions"); - // Replace the copy op with a new copy op that uses the source and destination - // of the subview. - rewriter.replaceOpWithNewOp<nvgpu::DeviceAsyncCopyOp>( - copyOp, nvgpu::DeviceAsyncTokenType::get(copyOp.getContext()), - (dstSubViewOp ? dstSubViewOp.getSource() : copyOp.getDst()), - foldedDstIndices, - (srcSubViewOp ? 
srcSubViewOp.getSource() : copyOp.getSrc()), - foldedSrcIndices, copyOp.getDstElements(), copyOp.getSrcElements(), - copyOp.getBypassL1Attr()); + SmallVector<Value> newIndices; + memref::resolveSourceIndicesCollapseShape(op.getLoc(), rewriter, collapse, + op.getIndices(), newIndices); + op.updateStartingPosition(rewriter, collapse.getViewSource(), newIndices, + AffineMapAttr::get(newPerm)); return success(); } void memref::populateFoldMemRefAliasOpPatterns(RewritePatternSet &patterns) { - patterns.add<LoadOpOfSubViewOpFolder<memref::LoadOp>, - LoadOpOfSubViewOpFolder<nvgpu::LdMatrixOp>, - LoadOpOfSubViewOpFolder<vector::LoadOp>, - LoadOpOfSubViewOpFolder<vector::MaskedLoadOp>, - LoadOpOfSubViewOpFolder<vector::TransferReadOp>, - LoadOpOfSubViewOpFolder<gpu::SubgroupMmaLoadMatrixOp>, - StoreOpOfSubViewOpFolder<memref::StoreOp>, - StoreOpOfSubViewOpFolder<vector::TransferWriteOp>, - StoreOpOfSubViewOpFolder<vector::StoreOp>, - StoreOpOfSubViewOpFolder<vector::MaskedStoreOp>, - StoreOpOfSubViewOpFolder<gpu::SubgroupMmaStoreMatrixOp>, - LoadOpOfExpandShapeOpFolder<memref::LoadOp>, - LoadOpOfExpandShapeOpFolder<vector::LoadOp>, - LoadOpOfExpandShapeOpFolder<vector::MaskedLoadOp>, - LoadOpOfExpandShapeOpFolder<vector::TransferReadOp>, - StoreOpOfExpandShapeOpFolder<memref::StoreOp>, - StoreOpOfExpandShapeOpFolder<vector::StoreOp>, - StoreOpOfExpandShapeOpFolder<vector::MaskedStoreOp>, - LoadOpOfCollapseShapeOpFolder<memref::LoadOp>, - LoadOpOfCollapseShapeOpFolder<vector::LoadOp>, - LoadOpOfCollapseShapeOpFolder<vector::MaskedLoadOp>, - StoreOpOfCollapseShapeOpFolder<memref::StoreOp>, - StoreOpOfCollapseShapeOpFolder<vector::StoreOp>, - StoreOpOfCollapseShapeOpFolder<vector::MaskedStoreOp>, - SubViewOfSubViewFolder, NVGPUAsyncCopyOpSubViewOpFolder>( - patterns.getContext()); + patterns + .add<AccessOpOfSubViewOpFolder, AccessOpOfExpandShapeOpFolder, + AccessOpOfCollapseShapeOpFolder, IndexedMemCopyOpOfSubViewOpFolder, + IndexedMemCopyOpOfExpandShapeOpFolder, + IndexedMemCopyOpOfCollapseShapeOpFolder, TransferOpOfSubViewOpFolder, + TransferOpOfExpandShapeOpFolder, TransferOpOfCollapseShapeOpFolder, + SubViewOfSubViewFolder>(patterns.getContext()); } //===----------------------------------------------------------------------===// diff --git a/mlir/test/Dialect/MemRef/fold-memref-alias-ops.mlir b/mlir/test/Dialect/MemRef/fold-memref-alias-ops.mlir index 93e5ba462584a..0a2cc436cebb3 100644 --- a/mlir/test/Dialect/MemRef/fold-memref-alias-ops.mlir +++ b/mlir/test/Dialect/MemRef/fold-memref-alias-ops.mlir @@ -815,7 +815,7 @@ func.func @fold_vector_transfer_read_expand_shape( // CHECK-SAME: %[[ARG1:[a-zA-Z0-9_]+]]: index // CHECK: %[[C0:.*]] = arith.constant 0 // CHECK: %[[PAD:.*]] = ub.poison : f32 -// CHECK: %[[IDX:.*]] = affine.linearize_index [%[[ARG1]], %[[C0]]] by (4, 8) +// CHECK: %[[IDX:.*]] = affine.linearize_index disjoint [%[[ARG1]], %[[C0]]] by (4, 8) // CHECK: vector.transfer_read %[[ARG0]][%[[IDX]]], %[[PAD]] {in_bounds = [true]} // ----- @@ -911,3 +911,294 @@ func.func @fold_vector_maskedstore_collapse_shape( // CHECK-SAME: %[[ARG4:[a-zA-Z0-9_]+]]: vector<8xf32> // CHECK: %[[IDXS:.*]]:2 = affine.delinearize_index %[[ARG1]] into (4, 8) // CHECK: vector.maskedstore %[[ARG0]][%[[IDXS]]#0, %[[IDXS]]#1], %[[ARG3]], %[[ARG4]] + +// ----- + +func.func @fold_subview_non_unit_stride_with_vector_load( + %arg0 : memref<24x64xf32>, %arg1 : index, %arg2 : index, %arg3 : index) -> vector<8xf32> { + %c0 = arith.constant 0 : index + %0 = memref.subview %arg0[%arg1, %arg2][12, 8][2, 1] : 
memref<24x64xf32> to memref<12x8xf32, strided<[128, 1], offset: ?>> + %1 = vector.load %0[%arg3, %c0] : memref<12x8xf32, strided<[128, 1], offset: ?>>, vector<8xf32> + return %1 : vector<8xf32> +} + +// CHECK-DAG: #[[$MAP:.+]] = affine_map<()[s0, s1] -> (s0 + s1 * 2)> +// CHECK: func @fold_subview_non_unit_stride_with_vector_load +// CHECK-SAME: %[[ARG0:[a-zA-Z0-9_]+]]: memref<24x64xf32> +// CHECK-SAME: %[[ARG1:[a-zA-Z0-9_]+]]: index +// CHECK-SAME: %[[ARG2:[a-zA-Z0-9_]+]]: index +// CHECK-SAME: %[[ARG3:[a-zA-Z0-9_]+]]: index +// CHECK: %[[I0:.+]] = affine.apply #[[$MAP]]()[%[[ARG1]], %[[ARG3]]] +// CHECK: vector.load %[[ARG0]][%[[I0]], %[[ARG2]]] + +// ----- + +func.func @no_fold_expand_shape_2d_vector_load_matching_ranks( + %arg0 : memref<4x16xf32>, %arg1 : index) -> vector<4x4xf32> { + %c0 = arith.constant 0 : index + %0 = memref.expand_shape %arg0 [[0], [1, 2]] output_shape [4, 4, 4] : memref<4x16xf32> into memref<4x4x4xf32> + %1 = vector.load %0[%arg1, %c0, %c0] : memref<4x4x4xf32>, vector<4x4xf32> + return %1 : vector<4x4xf32> +} + +// CHECK-LABEL: func @no_fold_expand_shape_2d_vector_load_matching_ranks +// CHECK-SAME: %[[ARG0:[a-zA-Z0-9_]+]]: memref<4x16xf32> +// CHECK: memref.expand_shape %[[ARG0]] {{\[}}[0], [1, 2]] output_shape [4, 4, 4] : memref<4x16xf32> into memref<4x4x4xf32> +// CHECK: vector.load + +// ----- + +func.func @fold_expand_shape_2d_vector_load( + %arg0 : memref<64x8xf32>, %arg1 : index) -> vector<2x8xf32> { + %c0 = arith.constant 0 : index + %0 = memref.expand_shape %arg0 [[0, 1], [2]] output_shape [4, 16, 8] : memref<64x8xf32> into memref<4x16x8xf32> + %1 = vector.load %0[%arg1, %c0, %c0] : memref<4x16x8xf32>, vector<2x8xf32> + return %1 : vector<2x8xf32> +} + +// CHECK-LABEL: func @fold_expand_shape_2d_vector_load +// CHECK-SAME: %[[ARG0:[a-zA-Z0-9_]+]]: memref<64x8xf32> +// CHECK-SAME: %[[ARG1:[a-zA-Z0-9_]+]]: index +// CHECK: %[[C0:.*]] = arith.constant 0 +// CHECK: %[[IDX:.*]] = affine.linearize_index [%[[ARG1]], %[[C0]]] by (4, 16) +// CHECK: vector.load %[[ARG0]][%[[IDX]], %[[C0]]] + +// ----- + +func.func @no_fold_collapse_shape_transfer_read( + %arg0 : memref<4x4x8xf32>, %arg1 : index) -> vector<4x8xf32> { + %c0 = arith.constant 0 : index + %pad = ub.poison : f32 + %0 = memref.collapse_shape %arg0 [[0, 1], [2]] : memref<4x4x8xf32> into memref<16x8xf32> + %1 = vector.transfer_read %0[%arg1, %c0], %pad {in_bounds = [true, true]} : memref<16x8xf32>, vector<4x8xf32> + return %1 : vector<4x8xf32> +} + +// CHECK-LABEL: func @no_fold_collapse_shape_transfer_read +// CHECK-SAME: %[[ARG0:[a-zA-Z0-9_]+]]: memref<4x4x8xf32> +// CHECK: memref.collapse_shape %[[ARG0]] +// CHECK: vector.transfer_read + +// ----- + +func.func @fold_collapse_shape_transfer_read( + %arg0 : memref<4x4x8xf32>, %arg1 : index) -> vector<8xf32> { + %c0 = arith.constant 0 : index + %pad = ub.poison : f32 + %0 = memref.collapse_shape %arg0 [[0, 1], [2]] : memref<4x4x8xf32> into memref<16x8xf32> + %1 = vector.transfer_read %0[%arg1, %c0], %pad {in_bounds = [true]} : memref<16x8xf32>, vector<8xf32> + return %1 : vector<8xf32> +} + +// CHECK-LABEL: func @fold_collapse_shape_transfer_read +// CHECK-SAME: %[[ARG0:[a-zA-Z0-9_]+]]: memref<4x4x8xf32> +// CHECK-SAME: %[[ARG1:[a-zA-Z0-9_]+]]: index +// CHECK: %[[C0:.*]] = arith.constant 0 +// CHECK: %[[PAD:.*]] = ub.poison : f32 +// CHECK: %[[IDXS:.*]]:2 = affine.delinearize_index %[[ARG1]] into (4, 4) +// CHECK: vector.transfer_read %[[ARG0]][%[[IDXS]]#0, %[[IDXS]]#1, %[[C0]]], %[[PAD]] {in_bounds = [true]} + +// ----- + +func.func 
@fold_dma_start_subview_src( + %src : memref<128x64xf32>, %dst : memref<32xf32, 1>, %tag : memref<1xi32>, + %off0 : index, %off1 : index) { + %c0 = arith.constant 0 : index + %num_elements = arith.constant 32 : index + %subview = memref.subview %src[%off0, %off1][32, 32][1, 1] : memref<128x64xf32> to memref<32x32xf32, strided<[64, 1], offset: ?>> + memref.dma_start %subview[%c0, %c0], %dst[%c0], %num_elements, %tag[%c0] : memref<32x32xf32, strided<[64, 1], offset: ?>>, memref<32xf32, 1>, memref<1xi32> + return +} + +// CHECK-LABEL: func @fold_dma_start_subview_src +// CHECK-SAME: %[[SRC:[a-zA-Z0-9_]+]]: memref<128x64xf32> +// CHECK-SAME: %[[DST:[a-zA-Z0-9_]+]]: memref<32xf32, 1> +// CHECK-SAME: %[[TAG:[a-zA-Z0-9_]+]]: memref<1xi32> +// CHECK-SAME: %[[OFF0:[a-zA-Z0-9_]+]]: index +// CHECK-SAME: %[[OFF1:[a-zA-Z0-9_]+]]: index +// CHECK-DAG: %[[C0:.*]] = arith.constant 0 +// CHECK-DAG: %[[NUM:.*]] = arith.constant 32 +// CHECK: memref.dma_start %[[SRC]][%[[OFF0]], %[[OFF1]]], %[[DST]][%[[C0]]], %[[NUM]], %[[TAG]][%[[C0]]] + +// ----- + +func.func @fold_dma_start_subview_dst( + %src : memref<32xf32>, %dst : memref<128x64xf32, 1>, %tag : memref<1xi32>, + %off0 : index, %off1 : index) { + %c0 = arith.constant 0 : index + %num_elements = arith.constant 32 : index + %subview = memref.subview %dst[%off0, %off1][32, 32][1, 1] : memref<128x64xf32, 1> to memref<32x32xf32, strided<[64, 1], offset: ?>, 1> + memref.dma_start %src[%c0], %subview[%c0, %c0], %num_elements, %tag[%c0] : memref<32xf32>, memref<32x32xf32, strided<[64, 1], offset: ?>, 1>, memref<1xi32> + return +} +// CHECK-LABEL: func @fold_dma_start_subview_dst +// CHECK-SAME: %[[SRC:[a-zA-Z0-9_]+]]: memref<32xf32> +// CHECK-SAME: %[[DST:[a-zA-Z0-9_]+]]: memref<128x64xf32, 1> +// CHECK-SAME: %[[TAG:[a-zA-Z0-9_]+]]: memref<1xi32> +// CHECK-SAME: %[[OFF0:[a-zA-Z0-9_]+]]: index +// CHECK-SAME: %[[OFF1:[a-zA-Z0-9_]+]]: index +// CHECK-DAG: %[[C0:.*]] = arith.constant 0 +// CHECK-DAG: %[[NUM:.*]] = arith.constant 32 +// CHECK: memref.dma_start %[[SRC]][%[[C0]]], %[[DST]][%[[OFF0]], %[[OFF1]]], %[[NUM]], %[[TAG]][%[[C0]]] + +// ----- + +func.func @fold_dma_start_expand_shape_src( + %src : memref<32xf32>, %dst : memref<8xf32, 1>, %tag : memref<1xi32>, + %idx : index) { + %c0 = arith.constant 0 : index + %num_elements = arith.constant 8 : index + %expand = memref.expand_shape %src [[0, 1]] output_shape [4, 8] : memref<32xf32> into memref<4x8xf32> + memref.dma_start %expand[%idx, %c0], %dst[%c0], %num_elements, %tag[%c0] : memref<4x8xf32>, memref<8xf32, 1>, memref<1xi32> + return +} + +// CHECK-LABEL: func @fold_dma_start_expand_shape_src +// CHECK-SAME: %[[SRC:[a-zA-Z0-9_]+]]: memref<32xf32> +// CHECK-SAME: %[[DST:[a-zA-Z0-9_]+]]: memref<8xf32, 1> +// CHECK-SAME: %[[TAG:[a-zA-Z0-9_]+]]: memref<1xi32> +// CHECK-SAME: %[[IDX:[a-zA-Z0-9_]+]]: index +// CHECK-DAG: %[[C0:.*]] = arith.constant 0 +// CHECK-DAG: %[[NUM:.*]] = arith.constant 8 +// CHECK: %[[I:.*]] = affine.linearize_index disjoint [%[[IDX]], %[[C0]]] by (4, 8) +// CHECK: memref.dma_start %[[SRC]][%[[I]]], %[[DST]][%[[C0]]], %[[NUM]], %[[TAG]][%[[C0]]] + +// ----- + +func.func @fold_dma_start_expand_shape_dst( + %src : memref<8xf32>, %dst : memref<32xf32, 1>, %tag : memref<1xi32>, + %idx : index) { + %c0 = arith.constant 0 : index + %num_elements = arith.constant 8 : index + %expand = memref.expand_shape %dst [[0, 1]] output_shape [4, 8] : memref<32xf32, 1> into memref<4x8xf32, 1> + memref.dma_start %src[%c0], %expand[%idx, %c0], %num_elements, %tag[%c0] : memref<8xf32>, memref<4x8xf32, 1>, 
memref<1xi32> + return +} + +// CHECK-LABEL: func @fold_dma_start_expand_shape_dst +// CHECK-SAME: %[[SRC:[a-zA-Z0-9_]+]]: memref<8xf32> +// CHECK-SAME: %[[DST:[a-zA-Z0-9_]+]]: memref<32xf32, 1> +// CHECK-SAME: %[[TAG:[a-zA-Z0-9_]+]]: memref<1xi32> +// CHECK-SAME: %[[IDX:[a-zA-Z0-9_]+]]: index +// CHECK-DAG: %[[C0:.*]] = arith.constant 0 +// CHECK-DAG: %[[NUM:.*]] = arith.constant 8 +// CHECK: %[[I:.*]] = affine.linearize_index disjoint [%[[IDX]], %[[C0]]] by (4, 8) +// CHECK: memref.dma_start %[[SRC]][%[[C0]]], %[[DST]][%[[I]]], %[[NUM]], %[[TAG]][%[[C0]]] + +// ----- + +func.func @fold_dma_start_collapse_shape_src( + %src : memref<4x8xf32>, %dst : memref<8xf32, 1>, %tag : memref<1xi32>, + %idx : index) { + %c0 = arith.constant 0 : index + %num_elements = arith.constant 8 : index + %collapse = memref.collapse_shape %src [[0, 1]] : memref<4x8xf32> into memref<32xf32> + memref.dma_start %collapse[%idx], %dst[%c0], %num_elements, %tag[%c0] : memref<32xf32>, memref<8xf32, 1>, memref<1xi32> + return +} + +// CHECK-LABEL: func @fold_dma_start_collapse_shape_src +// CHECK-SAME: %[[SRC:[a-zA-Z0-9_]+]]: memref<4x8xf32> +// CHECK-SAME: %[[DST:[a-zA-Z0-9_]+]]: memref<8xf32, 1> +// CHECK-SAME: %[[TAG:[a-zA-Z0-9_]+]]: memref<1xi32> +// CHECK-SAME: %[[IDX:[a-zA-Z0-9_]+]]: index +// CHECK-DAG: %[[C0:.*]] = arith.constant 0 +// CHECK-DAG: %[[NUM:.*]] = arith.constant 8 +// CHECK: %[[IDXS:.*]]:2 = affine.delinearize_index %[[IDX]] into (4, 8) +// CHECK: memref.dma_start %[[SRC]][%[[IDXS]]#0, %[[IDXS]]#1], %[[DST]][%[[C0]]], %[[NUM]], %[[TAG]][%[[C0]]] + +// ----- + +func.func @fold_dma_start_collapse_shape_dst( + %src : memref<8xf32>, %dst : memref<4x8xf32, 1>, %tag : memref<1xi32>, + %idx : index) { + %c0 = arith.constant 0 : index + %num_elements = arith.constant 8 : index + %collapse = memref.collapse_shape %dst [[0, 1]] : memref<4x8xf32, 1> into memref<32xf32, 1> + memref.dma_start %src[%c0], %collapse[%idx], %num_elements, %tag[%c0] : memref<8xf32>, memref<32xf32, 1>, memref<1xi32> + return +} + +// CHECK-LABEL: func @fold_dma_start_collapse_shape_dst +// CHECK-SAME: %[[SRC:[a-zA-Z0-9_]+]]: memref<8xf32> +// CHECK-SAME: %[[DST:[a-zA-Z0-9_]+]]: memref<4x8xf32, 1> +// CHECK-SAME: %[[TAG:[a-zA-Z0-9_]+]]: memref<1xi32> +// CHECK-SAME: %[[IDX:[a-zA-Z0-9_]+]]: index +// CHECK-DAG: %[[C0:.*]] = arith.constant 0 +// CHECK-DAG: %[[NUM:.*]] = arith.constant 8 +// CHECK: %[[IDXS:.*]]:2 = affine.delinearize_index %[[IDX]] into (4, 8) +// CHECK: memref.dma_start %[[SRC]][%[[C0]]], %[[DST]][%[[IDXS]]#0, %[[IDXS]]#1], %[[NUM]], %[[TAG]][%[[C0]]] + +// ----- + +func.func @fold_vector_load_expand_shape_unit_dims( + %arg0 : memref<64xf32>, %arg1 : index) -> vector<1x1x8xf32> { + %c0 = arith.constant 0 : index + %0 = memref.expand_shape %arg0 [[0, 1, 2]] output_shape [1, 1, 64] : memref<64xf32> into memref<1x1x64xf32> + %1 = vector.load %0[%c0, %c0, %arg1] : memref<1x1x64xf32>, vector<1x1x8xf32> + return %1 : vector<1x1x8xf32> +} + +// CHECK-LABEL: func @fold_vector_load_expand_shape_unit_dims +// CHECK-SAME: %[[ARG0:[a-zA-Z0-9_]+]]: memref<64xf32> +// CHECK-SAME: %[[ARG1:[a-zA-Z0-9_]+]]: index +// CHECK: %[[IDX:.*]] = affine.linearize_index [%{{.*}}, %{{.*}}, %[[ARG1]]] by (1, 1, 64) +// CHECK: %[[LOAD:.*]] = vector.load %[[ARG0]][%[[IDX]]] : memref<64xf32>, vector<8xf32> +// CHECK: %[[CAST:.*]] = vector.shape_cast %[[LOAD]] : vector<8xf32> to vector<1x1x8xf32> +// CHECK: return %[[CAST]] + +// ----- + +func.func @fold_vector_store_expand_shape_unit_dims( + %arg0 : memref<64xf32>, %arg1 : index, %val : 
vector<1x1x8xf32>) { + %c0 = arith.constant 0 : index + %0 = memref.expand_shape %arg0 [[0, 1, 2]] output_shape [1, 1, 64] : memref<64xf32> into memref<1x1x64xf32> + vector.store %val, %0[%c0, %c0, %arg1] : memref<1x1x64xf32>, vector<1x1x8xf32> + return +} + +// CHECK-LABEL: func @fold_vector_store_expand_shape_unit_dims +// CHECK-SAME: %[[ARG0:[a-zA-Z0-9_]+]]: memref<64xf32> +// CHECK-SAME: %[[ARG1:[a-zA-Z0-9_]+]]: index +// CHECK-SAME: %[[VAL:[a-zA-Z0-9_]+]]: vector<1x1x8xf32> +// CHECK: %[[IDX:.*]] = affine.linearize_index [%{{.*}}, %{{.*}}, %[[ARG1]]] by (1, 1, 64) +// CHECK: %[[CAST:.*]] = vector.shape_cast %[[VAL]] : vector<1x1x8xf32> to vector<8xf32> +// CHECK: vector.store %[[CAST]], %[[ARG0]][%[[IDX]]] : memref<64xf32>, vector<8xf32> + +// ----- + +func.func @fold_vector_load_collapse_shape_with_unit_dim_vector( + %arg0 : memref<5x4x8x16xf32>, %arg1 : index, %arg2 : index) -> vector<1x1x4xf32> { + %c0 = arith.constant 0 : index + %0 = memref.collapse_shape %arg0 [[0], [1, 2], [3]] : memref<5x4x8x16xf32> into memref<5x32x16xf32> + %1 = vector.load %0[%arg1, %arg2, %c0] : memref<5x32x16xf32>, vector<1x1x4xf32> + return %1 : vector<1x1x4xf32> +} + +// CHECK-LABEL: func @fold_vector_load_collapse_shape_with_unit_dim_vector +// CHECK-SAME: %[[ARG0:[a-zA-Z0-9_]+]]: memref<5x4x8x16xf32> +// CHECK-SAME: %[[ARG1:[a-zA-Z0-9_]+]]: index +// CHECK-SAME: %[[ARG2:[a-zA-Z0-9_]+]]: index +// CHECK: %[[C0:.*]] = arith.constant 0 : index +// CHECK: %[[IDXS:.*]]:2 = affine.delinearize_index %[[ARG2]] into (4, 8) +// CHECK: %[[LOAD:.*]] = vector.load %[[ARG0]][%[[ARG1]], %[[IDXS]]#0, %[[IDXS]]#1, %[[C0]]] : memref<5x4x8x16xf32>, vector<1x1x4xf32> +// CHECK: return %[[LOAD]] + +// ----- + +func.func @fold_vector_store_collapse_shape_with_unit_dim_vector( + %arg0 : memref<5x4x8x16xf32>, %arg1 : index, %arg2 : index, %val : vector<1x1x4xf32>) { + %c0 = arith.constant 0 : index + %0 = memref.collapse_shape %arg0 [[0], [1, 2], [3]] : memref<5x4x8x16xf32> into memref<5x32x16xf32> + vector.store %val, %0[%arg1, %arg2, %c0] : memref<5x32x16xf32>, vector<1x1x4xf32> + return +} + +// CHECK-LABEL: func @fold_vector_store_collapse_shape_with_unit_dim_vector +// CHECK-SAME: %[[ARG0:[a-zA-Z0-9_]+]]: memref<5x4x8x16xf32> +// CHECK-SAME: %[[ARG1:[a-zA-Z0-9_]+]]: index +// CHECK-SAME: %[[ARG2:[a-zA-Z0-9_]+]]: index +// CHECK-SAME: %[[VAL:[a-zA-Z0-9_]+]]: vector<1x1x4xf32> +// CHECK: %[[C0:.*]] = arith.constant 0 : index +// CHECK: %[[IDXS:.*]]:2 = affine.delinearize_index %[[ARG2]] into (4, 8) +// CHECK: vector.store %[[VAL]], %[[ARG0]][%[[ARG1]], %[[IDXS]]#0, %[[IDXS]]#1, %[[C0]]] : memref<5x4x8x16xf32>, vector<1x1x4xf32> _______________________________________________ llvm-branch-commits mailing list [email protected] https://lists.llvm.org/cgi-bin/mailman/listinfo/llvm-branch-commits
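For readers skimming the commit message, item 5 above corresponds to rewrites of roughly the following shape. This is an illustrative MLIR sketch mirroring the new @fold_vector_load_expand_shape_unit_dims test in the patch; the function and value names are made up, and the exact folded output may differ slightly in how the index arithmetic is emitted.

// Input: a vector<1x1x8xf32> load through an expand_shape that only
// introduces unit dimensions.
func.func @unit_dim_fold_sketch(%mem : memref<64xf32>, %i : index) -> vector<1x1x8xf32> {
  %c0 = arith.constant 0 : index
  %e = memref.expand_shape %mem [[0, 1, 2]] output_shape [1, 1, 64]
      : memref<64xf32> into memref<1x1x64xf32>
  %v = vector.load %e[%c0, %c0, %i] : memref<1x1x64xf32>, vector<1x1x8xf32>
  return %v : vector<1x1x8xf32>
}

// After fold-memref-alias-ops, the expand_shape is folded away: the indices
// are linearized, the load becomes a 1-D vector<8xf32> load, and a
// vector.shape_cast restores the original 1x1x8 type, e.g.:
//   %idx  = affine.linearize_index [%c0, %c0, %i] by (1, 1, 64) : index
//   %v1d  = vector.load %mem[%idx] : memref<64xf32>, vector<8xf32>
//   %v    = vector.shape_cast %v1d : vector<8xf32> to vector<1x1x8xf32>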
