https://github.com/Hsiangkai updated https://github.com/llvm/llvm-project/pull/96177
>From 276ed8981c5243696da3bf233a777e1b84f11131 Mon Sep 17 00:00:00 2001 From: Hsiangkai Wang <hsiangkai.w...@arm.com> Date: Mon, 17 Jun 2024 11:24:07 +0100 Subject: [PATCH 1/2] [mlir][linalg] Implement Conv2D using Winograd Conv2D algorithm Define high level winograd operators and convert conv_2d_nhwc_fhwc into winograd operators. According to Winograd Conv2D algorithm, we need three transform operators for input, filter, and output transformation. The formula of Winograd Conv2D algorithm is Y = A^T x [(G x g x G^T) @ (B^T x d x B)] x A filter transform: G x g x G^T input transform: B^T x d x B output transform: A^T x y x A The implementation is based on the paper, Fast Algorithm for Convolutional Neural Networks. (https://arxiv.org/abs/1509.09308) --- .../mlir/Dialect/Linalg/IR/LinalgOps.td | 114 +++++++ .../Dialect/Linalg/Transforms/Transforms.h | 4 + mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp | 78 +++++ .../Dialect/Linalg/Transforms/CMakeLists.txt | 1 + .../Linalg/Transforms/WinogradConv2D.cpp | 321 ++++++++++++++++++ mlir/test/Dialect/Linalg/winograd-conv2d.mlir | 248 ++++++++++++++ .../Dialect/Linalg/TestLinalgTransforms.cpp | 13 + 7 files changed, 779 insertions(+) create mode 100644 mlir/lib/Dialect/Linalg/Transforms/WinogradConv2D.cpp create mode 100644 mlir/test/Dialect/Linalg/winograd-conv2d.mlir diff --git a/mlir/include/mlir/Dialect/Linalg/IR/LinalgOps.td b/mlir/include/mlir/Dialect/Linalg/IR/LinalgOps.td index 64c538367267d..de1097b6ac27b 100644 --- a/mlir/include/mlir/Dialect/Linalg/IR/LinalgOps.td +++ b/mlir/include/mlir/Dialect/Linalg/IR/LinalgOps.td @@ -154,4 +154,118 @@ def Linalg_SoftmaxOp : Linalg_Op<"softmax", let hasVerifier = 1; } +def Linalg_WinogradFilterTransformOp : Linalg_Op<"winograd_filter_transform"> { + let summary = "Winograd filter transform operator"; + let description = [{ + Winograd Conv2D algorithm will convert linalg Conv2D operator into batched + matrix multiply. 
Before the matrix multiply, it will convert filter and + input into a format suitable for batched matrix multiply. After the matrix + multiply, it will convert output to the final result tensor. + + The algorithm F(m x m, r x r) is + + Y = A^T x [(G x g x G^T) @ (B^T x d x B)] x A + + The size of output Y is m x m. The size of filter g is r x r. The size of + input d is (m + r - 1) x (m + r - 1). A^T, A, G^T, G, B^T, and B are + transformation matrices. + + This operator is defined to represent the high level concept of filter + transformation (G x g x G^T) in the Winograd Conv2D algorithm. + }]; + + let arguments = (ins AnyRankedTensor:$filter, + AnyRankedTensor:$output, + I64Attr:$m, + I64Attr:$r + ); + + let results = (outs AnyRankedTensor:$result); + let assemblyFormat = [{ + attr-dict + `m` `(` $m `)` + `r` `(` $r `)` + `ins` `(` $filter `:` type($filter) `)` + `outs` `(` $output `:` type($output) `)` + `->` type($result) + }]; + let hasVerifier = 1; +} + +def Linalg_WinogradInputTransformOp : Linalg_Op<"winograd_input_transform"> { + let summary = "Winograd input transform operator"; + let description = [{ + Winograd Conv2D algorithm will convert linalg Conv2D operator into batched + matrix multiply. Before the matrix multiply, it will convert filter and + input into a format suitable for batched matrix multiply. After the matrix + multiply, it will convert output to the final result tensor. + + The algorithm F(m x m, r x r) is + + Y = A^T x [(G x g x G^T) @ (B^T x d x B)] x A + + The size of output Y is m x m. The size of filter g is r x r. The size of + input d is (m + r - 1) x (m + r - 1). A^T, A, G^T, G, B^T, and B are + transformation matrices. + + This operator is defined to represent the high level concept of input + transformation (B^T x d x B) in the Winograd Conv2D algorithm. 
+ }]; + + let arguments = (ins AnyRankedTensor:$input, + AnyRankedTensor:$output, + I64Attr:$m, + I64Attr:$r + ); + + let results = (outs AnyRankedTensor:$result); + let assemblyFormat = [{ + attr-dict + `m` `(` $m `)` + `r` `(` $r `)` + `ins` `(` $input `:` type($input) `)` + `outs` `(` $output `:` type($output) `)` + `->` type($result) + }]; + let hasVerifier = 1; +} + +def Linalg_WinogradOutputTransformOp : Linalg_Op<"winograd_output_transform"> { + let summary = "Winograd output transform operator"; + let description = [{ + Winograd Conv2D algorithm will convert linalg Conv2D operator into batched + matrix multiply. Before the matrix multiply, it will convert filter and + input into a format suitable for batched matrix multiply. After the matrix + multiply, it will convert output to the final result tensor. + + The algorithm F(m x m, r x r) is + + Y = A^T x [(G x g x G^T) @ (B^T x d x B)] x A + + The size of output Y is m x m. The size of filter g is r x r. The size of + input d is (m + r - 1) x (m + r - 1). A^T, A, G^T, G, B^T, and B are + transformation matrices. + + This operator is defined to represent the high level concept of output + transformation (A^T x y x A) in the Winograd Conv2D algorithm. 
+ }]; + + let arguments = (ins AnyRankedTensor:$value, + AnyRankedTensor:$output, + I64Attr:$m, + I64Attr:$r + ); + + let results = (outs AnyRankedTensor:$result); + let assemblyFormat = [{ + attr-dict + `m` `(` $m `)` + `r` `(` $r `)` + `ins` `(` $value `:` type($value) `)` + `outs` `(` $output `:` type($output) `)` + `->` type($result) + }]; + let hasVerifier = 1; +} + #endif // LINALG_OPS diff --git a/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h b/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h index 05e97befdec1f..835aeaf2ffed3 100644 --- a/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h +++ b/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h @@ -1692,6 +1692,10 @@ void populateTransposeMatmulPatterns(RewritePatternSet &patterns, void populateBlockPackMatmulPatterns(RewritePatternSet &patterns, const ControlBlockPackMatmulFn &controlFn); +/// Patterns to apply Winograd Conv2D algorithm F(m x m, r x r). +void populateWinogradConv2DPatterns(RewritePatternSet &patterns, int64_t m, + int64_t r); + } // namespace linalg } // namespace mlir diff --git a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp index 57d126603ebd7..7bf2a5bca037f 100644 --- a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp +++ b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp @@ -2734,6 +2734,84 @@ FailureOr<SmallVector<Value>> SoftmaxOp::decomposeOperation(OpBuilder &b) { return SmallVector<Value>{result}; } +//===----------------------------------------------------------------------===// +// WinogradFilterTransformOp +//===----------------------------------------------------------------------===// + +LogicalResult WinogradFilterTransformOp::verify() { + auto filterType = cast<ShapedType>(getFilter().getType()); + auto outputType = cast<ShapedType>(getOutput().getType()); + auto filterElemType = filterType.getElementType(); + auto outputElemType = outputType.getElementType(); + if (filterElemType != outputElemType) { + return 
emitOpError() << "expected element type of input " << filterElemType + << " to match element type of output " + << outputElemType; + } + + unsigned filterRank = filterType.getRank(); + if (filterRank != 4) + return emitOpError() << "expected rank of input is 4"; + + unsigned outputRank = outputType.getRank(); + if (outputRank != 6) + return emitOpError() << "expected rank of output is 6"; + + return success(); +} + +//===----------------------------------------------------------------------===// +// WinogradInputTransformOp +//===----------------------------------------------------------------------===// + +LogicalResult WinogradInputTransformOp::verify() { + auto inputType = cast<ShapedType>(getInput().getType()); + auto outputType = cast<ShapedType>(getOutput().getType()); + auto inputElemType = inputType.getElementType(); + auto outputElemType = outputType.getElementType(); + if (inputElemType != outputElemType) { + return emitOpError() << "expected element type of input " << inputElemType + << " to match element type of output " + << outputElemType; + } + + unsigned inputRank = inputType.getRank(); + if (inputRank != 4) + return emitOpError() << "expected rank of input is 4"; + + unsigned outputRank = outputType.getRank(); + if (outputRank != 6) + return emitOpError() << "expected rank of output is 6"; + + return success(); +} + +//===----------------------------------------------------------------------===// +// WinogradOutputTransformOp +//===----------------------------------------------------------------------===// + +LogicalResult WinogradOutputTransformOp::verify() { + auto valueType = cast<ShapedType>(getValue().getType()); + auto outputType = cast<ShapedType>(getOutput().getType()); + auto valueElemType = valueType.getElementType(); + auto outputElemType = outputType.getElementType(); + if (valueElemType != outputElemType) { + return emitOpError() << "expected element type of value " << valueElemType + << " to match element type of output " + << 
outputElemType; + } + + unsigned valueRank = valueType.getRank(); + if (valueRank != 6) + return emitOpError() << "expected rank of input is 6"; + + unsigned outputRank = outputType.getRank(); + if (outputRank != 4) + return emitOpError() << "expected rank of output is 4"; + + return success(); +} + //===----------------------------------------------------------------------===// // LinalgDialect //===----------------------------------------------------------------------===// diff --git a/mlir/lib/Dialect/Linalg/Transforms/CMakeLists.txt b/mlir/lib/Dialect/Linalg/Transforms/CMakeLists.txt index 7e3dc56e0acdc..a7dcc29b5b9be 100644 --- a/mlir/lib/Dialect/Linalg/Transforms/CMakeLists.txt +++ b/mlir/lib/Dialect/Linalg/Transforms/CMakeLists.txt @@ -38,6 +38,7 @@ add_mlir_dialect_library(MLIRLinalgTransforms Transforms.cpp TransposeConv2D.cpp Vectorization.cpp + WinogradConv2D.cpp ADDITIONAL_HEADER_DIRS ${MLIR_MAIN_INCLUDE_DIR}/mlir/Dialect/Linalg diff --git a/mlir/lib/Dialect/Linalg/Transforms/WinogradConv2D.cpp b/mlir/lib/Dialect/Linalg/Transforms/WinogradConv2D.cpp new file mode 100644 index 0000000000000..86e834d51f2fc --- /dev/null +++ b/mlir/lib/Dialect/Linalg/Transforms/WinogradConv2D.cpp @@ -0,0 +1,321 @@ +//===- WinogradConv2D.cpp - Winograd Conv2D implementation ----------------===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// Implement Winograd Conv2D algorithm. 
The implementation is based on the +// paper: Fast Algorithms for Convolutional Neural Networks +// (https://arxiv.org/abs/1509.09308) +// +//===----------------------------------------------------------------------===// + +#include "mlir/Dialect/Linalg/IR/Linalg.h" +#include "mlir/Dialect/Tensor/IR/Tensor.h" +#include "mlir/Dialect/Tosa/Utils/ConversionUtils.h" +#include "mlir/Transforms/GreedyPatternRewriteDriver.h" +#include "llvm/Support/MathExtras.h" + +namespace mlir { +namespace linalg { + +namespace { + +using TransformMapKeyTy = std::pair<int, int>; + +// We use F(m, r) to define the size of minimal filtering algorithms. +// m is the output dimension and r is the filter dimension. We can get +// the input dimension, alpha, from the formula, alpha = m + r - 1. +// +// For example, when m = 2 and r = 3, we know its input size is 4. +// The Conv2D will operate on 4x4 input data with 3x3 filter and get +// 2x2 output result. +constexpr TransformMapKeyTy F_2_3{2, 3}; +constexpr TransformMapKeyTy F_4_3{4, 3}; +constexpr TransformMapKeyTy F_2_5{2, 5}; + +Value collapse2DData(RewriterBase &rewriter, Location loc, Value data) { + auto type = cast<ShapedType>(data.getType()); + auto elementType = type.getElementType(); + auto shape = type.getShape(); + auto collapseType = RankedTensorType::get( + {shape[0] * shape[1] * shape[2] * shape[3], shape[4], shape[5]}, + elementType); + SmallVector<ReassociationIndices> reassociation = {{0, 1, 2, 3}, {4}, {5}}; + return rewriter.create<tensor::CollapseShapeOp>(loc, collapseType, data, + reassociation); +} + +// This function generates linalg.batch_matmul to multiply input with filter. +// linalg.batch_matmul only supports 3-dimension data sets. We can treat +// tileH x tileW x H x W data as the 1-dimension data array. That is to convert +// [tileH, tileW, H, W, N, C] to [tileH x tileW x H x W, N, C]. 
In this way, we +// can convert 6-dimension input data to 3-dimension representation that is +// suitable for linalg.batch_matmul. +// +// Batched matmul will do the matrix multiply with the reduction on channel. +// +// We get +// +// %collapsed_input = tensor.collapse_shape %input +// %collapsed_filter = tensor.collapse_shape %filter +// %ret = linalg.batch_matmul %collapsed_input, %collapsed_filter +// %expanded_ret = tensor.expand_shape %ret +// +// After this function, we get return value with data layout +// (tileH, tileW, H, W, N, F). +Value matrixMultiply(RewriterBase &rewriter, Location loc, + Value transformedFilter, Value transformedInput) { + auto collapseFilter = collapse2DData(rewriter, loc, transformedFilter); + auto collapseInput = collapse2DData(rewriter, loc, transformedInput); + + // Batched matrix multiply + auto filterType = cast<ShapedType>(transformedFilter.getType()); + auto filterShape = filterType.getShape(); + auto inputType = cast<ShapedType>(transformedInput.getType()); + auto inputElemType = inputType.getElementType(); + auto inputShape = inputType.getShape(); + + auto matmulType = RankedTensorType::get( + {inputShape[0] * inputShape[1] * inputShape[2] * inputShape[3], + inputShape[4], filterShape[5]}, + inputElemType); + Value init = rewriter.create<tensor::EmptyOp>(loc, matmulType.getShape(), + inputElemType); + + auto matmulOp = rewriter.create<linalg::BatchMatmulOp>( + loc, matmulType, ValueRange({collapseInput, collapseFilter}), + ValueRange{init}); + + // Expand matmul result + SmallVector<ReassociationIndices> reassociation = {{0, 1, 2, 3}, {4}, {5}}; + auto expandType = + RankedTensorType::get({inputShape[0], inputShape[1], inputShape[2], + inputShape[3], inputShape[4], filterShape[5]}, + inputElemType); + auto expandOutput = rewriter.create<tensor::ExpandShapeOp>( + loc, expandType, matmulOp.getResult(0), reassociation); + return expandOutput; +} + +Value insertToAlignedTensor(RewriterBase &rewriter, Location loc, Value value, 
+ RankedTensorType alignedType) { + Value alignedInput = rewriter.create<tensor::EmptyOp>( + loc, alignedType.getShape(), alignedType.getElementType()); + + auto zeroIndex = rewriter.getIndexAttr(0); + auto oneIndex = rewriter.getIndexAttr(1); + SmallVector<OpFoldResult, 4> offsets(4, zeroIndex); + SmallVector<OpFoldResult, 4> strides(4, oneIndex); + + auto valueType = cast<ShapedType>(value.getType()); + auto valueShape = valueType.getShape(); + SmallVector<OpFoldResult, 4> sizes; + sizes.emplace_back(rewriter.getIndexAttr(valueShape[0])); + sizes.emplace_back(rewriter.getIndexAttr(valueShape[1])); + sizes.emplace_back(rewriter.getIndexAttr(valueShape[2])); + sizes.emplace_back(rewriter.getIndexAttr(valueShape[3])); + + return rewriter.create<tensor::InsertSliceOp>(loc, value, alignedInput, + offsets, sizes, strides); +} + +Value extractFromAlignedTensor(RewriterBase &rewriter, Location loc, + Value value, RankedTensorType extractedType) { + auto zeroIndex = rewriter.getIndexAttr(0); + auto oneIndex = rewriter.getIndexAttr(1); + SmallVector<OpFoldResult, 4> offsets(4, zeroIndex); + SmallVector<OpFoldResult, 4> strides(4, oneIndex); + + auto extractedShape = extractedType.getShape(); + SmallVector<OpFoldResult, 4> sizes; + sizes.emplace_back(rewriter.getIndexAttr(extractedShape[0])); + sizes.emplace_back(rewriter.getIndexAttr(extractedShape[1])); + sizes.emplace_back(rewriter.getIndexAttr(extractedShape[2])); + sizes.emplace_back(rewriter.getIndexAttr(extractedShape[3])); + + return rewriter.create<tensor::ExtractSliceOp>(loc, extractedType, value, + offsets, sizes, strides); +} + +bool hasAllOneValues(DenseIntElementsAttr attr) { + return llvm::all_of( + attr, [](const APInt &element) { return element.getSExtValue() == 1; }); +} + +FailureOr<Operation *> winogradConv2DHelper(RewriterBase &rewriter, + linalg::Conv2DNhwcFhwcOp convOp, + int64_t m, int64_t r) { + Value input = convOp.getInputs()[0]; + Value filter = convOp.getInputs()[1]; + Value output = 
convOp.getOutputs()[0]; + auto inputType = cast<ShapedType>(input.getType()); + auto filterType = cast<ShapedType>(filter.getType()); + auto outputType = cast<ShapedType>(output.getType()); + + if (!inputType.hasStaticShape()) + return rewriter.notifyMatchFailure(convOp, + "expected a static shape for the input"); + + if (!filterType.hasStaticShape()) + return rewriter.notifyMatchFailure( + convOp, "expected a static shape for the filter"); + + if (!hasAllOneValues(convOp.getDilations())) + return rewriter.notifyMatchFailure(convOp, + "expected all ones for dilations"); + + if (!hasAllOneValues(convOp.getStrides())) + return rewriter.notifyMatchFailure(convOp, "expected all ones for strides"); + + auto filterShape = filterType.getShape(); + int64_t filterF = filterShape[0]; + int64_t filterH = filterShape[1]; + int64_t filterW = filterShape[2]; + int64_t filterC = filterShape[3]; + auto inputShape = inputType.getShape(); + int64_t inputN = inputShape[0]; + int64_t inputH = inputShape[1]; + int64_t inputW = inputShape[2]; + int64_t inputC = inputShape[3]; + auto outputShape = outputType.getShape(); + int64_t outputN = outputShape[0]; + int64_t outputH = outputShape[1]; + int64_t outputW = outputShape[2]; + int64_t outputF = outputShape[3]; + + // Only support F(m x m, r x r), F(m x 1, r x 1) or F(1 x m, 1 x r) + bool isSupportedFilter = false; + if (filterH == filterW && filterH == r) + isSupportedFilter = true; + if (filterH == r && filterW == 1) + isSupportedFilter = true; + if (filterH == 1 && filterW == r) + isSupportedFilter = true; + + if (!isSupportedFilter) + return rewriter.notifyMatchFailure( + convOp, "only support filter (r x r), (r x 1) or (1 x r)"); + + // Currently, we support (m, r) = (2, 3) or (4, 3) or (2, 5) + static const llvm::SmallVector<TransformMapKeyTy, 3> validConfigs = { + F_2_3, F_4_3, F_2_5}; + + TransformMapKeyTy key = {m, r}; + auto it = std::find(validConfigs.begin(), validConfigs.end(), key); + // If we cannot find the constant 
transformation matrix, it means we do + // not support this configuration yet. + if (it == validConfigs.end()) + return failure(); + + // All the criteria are satisfied. We can do Winograd Conv2D. + Location loc = convOp.getLoc(); + + // For F(m x 1, r x 1), we only need to do left side transform. + bool leftTransform = filterH != 1; + // For F(1 x m, 1 x r), we only need to do right side transform. + bool rightTransform = filterW != 1; + int64_t heightM = leftTransform ? m : 1; + int64_t widthM = rightTransform ? m : 1; + int64_t heightR = leftTransform ? r : 1; + int64_t widthR = rightTransform ? r : 1; + + // --- Create operator for filter transform --- + Type elementType = filterType.getElementType(); + int64_t alphaH = heightM + heightR - 1; + int64_t alphaW = widthM + widthR - 1; + int64_t tileH = llvm::divideCeilSigned(outputH, heightM); + int64_t tileW = llvm::divideCeilSigned(outputW, widthM); + auto retType = RankedTensorType::get( + {tileH, tileW, alphaH, alphaW, filterC, filterF}, elementType); + Value retValue = + rewriter.create<tensor::EmptyOp>(loc, retType.getShape(), elementType); + auto transformedFilter = rewriter.create<linalg::WinogradFilterTransformOp>( + loc, retType, filter, retValue, m, r); + + // --- Create operator for input transform --- + + // When input size - (r - 1) is not aligned with output tile size, we need to + // pad the input data to create the full tiles as tiling. 
+ int64_t alignedInputH = tileH * heightM + (heightR - 1); + int64_t alignedInputW = tileW * widthM + (widthR - 1); + if (alignedInputH != inputH || alignedInputW != inputW) { + auto alignedInputType = RankedTensorType::get( + {inputN, alignedInputH, alignedInputW, inputC}, elementType); + input = insertToAlignedTensor(rewriter, loc, input, alignedInputType); + } + + retType = RankedTensorType::get( + {tileH, tileW, alphaH, alphaW, inputN, inputC}, elementType); + retValue = + rewriter.create<tensor::EmptyOp>(loc, retType.getShape(), elementType); + auto transformedInput = rewriter.create<linalg::WinogradInputTransformOp>( + loc, retType, input, retValue, m, r); + + Value matmulRet = + matrixMultiply(rewriter, loc, transformedFilter, transformedInput); + + // --- Create operator for output transform --- + + // When output size is not aligned with output tile size, we need to pad the + // output buffer to insert the full tiles after tiling. + int64_t alignedOutputH = tileH * heightM; + int64_t alignedOutputW = tileW * widthM; + bool isOutputUnaligned = + ((alignedOutputH != outputH) || (alignedOutputW != outputW)); + if (isOutputUnaligned) { + auto alignedOutputType = RankedTensorType::get( + {outputN, alignedOutputH, alignedOutputW, outputF}, elementType); + output = insertToAlignedTensor(rewriter, loc, output, alignedOutputType); + outputType = alignedOutputType; + } + + Value transformedOutput = rewriter.create<linalg::WinogradOutputTransformOp>( + loc, outputType, matmulRet, output, m, r); + + // When output size is not aligned with output tile size, extract the + // value from the padded buffer. 
+ if (isOutputUnaligned) { + transformedOutput = extractFromAlignedTensor( + rewriter, loc, transformedOutput, + RankedTensorType::get({outputN, outputH, outputW, outputF}, + elementType)); + } + + rewriter.replaceOp(convOp, transformedOutput); + + return transformedOutput.getDefiningOp(); +} + +class WinogradConv2DNhwcFhwc final + : public OpRewritePattern<linalg::Conv2DNhwcFhwcOp> { +public: + using OpRewritePattern::OpRewritePattern; + WinogradConv2DNhwcFhwc(mlir::MLIRContext *context, int64_t m, int64_t r) + : OpRewritePattern(context), m(m), r(r) {} + + LogicalResult matchAndRewrite(linalg::Conv2DNhwcFhwcOp convOp, + PatternRewriter &rewriter) const override { + if (failed(winogradConv2DHelper(rewriter, convOp, m, r))) + return failure(); + + return success(); + } + +private: + int64_t m; + int64_t r; +}; +} // end anonymous namespace + +//===----------------------------------------------------------------------===// +void populateWinogradConv2DPatterns(RewritePatternSet &patterns, int64_t m, + int64_t r) { + MLIRContext *context = patterns.getContext(); + patterns.insert<WinogradConv2DNhwcFhwc>(context, m, r); +} + +} // end namespace linalg +} // end namespace mlir diff --git a/mlir/test/Dialect/Linalg/winograd-conv2d.mlir b/mlir/test/Dialect/Linalg/winograd-conv2d.mlir new file mode 100644 index 0000000000000..6cca3c602d4c0 --- /dev/null +++ b/mlir/test/Dialect/Linalg/winograd-conv2d.mlir @@ -0,0 +1,248 @@ +// RUN: mlir-opt %s -split-input-file -test-linalg-transform-patterns=test-winograd-conv2d | FileCheck %s + +func.func @conv2d_4x4_3x3(%arg0: tensor<2x6x6x5xf32>, %arg1: tensor<2x3x3x5xf32>, %arg2: tensor<1xf32>) -> tensor<2x4x4x2xf32> { + %0 = tensor.empty() : tensor<2x4x4x2xf32> + %1 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2, d3) -> (0)>, affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} ins(%arg2 : tensor<1xf32>) outs(%0 : tensor<2x4x4x2xf32>) { + ^bb0(%in: f32, %out: 
f32): + linalg.yield %in : f32 + } -> tensor<2x4x4x2xf32> + %2 = linalg.conv_2d_nhwc_fhwc {dilations = dense<1> : tensor<2xi64>, strides = dense<1> : tensor<2xi64>} ins(%arg0, %arg1 : tensor<2x6x6x5xf32>, tensor<2x3x3x5xf32>) outs(%1 : tensor<2x4x4x2xf32>) -> tensor<2x4x4x2xf32> + return %2 : tensor<2x4x4x2xf32> +} + +// CHECK: #[[$MAP0:.+]] = affine_map<(d0, d1, d2, d3) -> (0)> +// CHECK: #[[$MAP1:.+]] = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)> +// CHECK-LABEL: func.func @conv2d_4x4_3x3 +// CHECK-SAME: (%[[ARG0:.*]]: tensor<2x6x6x5xf32>, %[[ARG1:.*]]: tensor<2x3x3x5xf32>, %[[ARG2:.*]]: tensor<1xf32>) -> tensor<2x4x4x2xf32> { +// CHECK: %[[S0:.*]] = tensor.empty() : tensor<2x4x4x2xf32> +// CHECK-NEXT: %[[S1:.*]] = linalg.generic {indexing_maps = [#[[$MAP0]], #[[$MAP1]]], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} ins(%[[ARG2]] : tensor<1xf32>) outs(%[[S0]] : tensor<2x4x4x2xf32>) { +// CHECK-NEXT: ^bb0(%[[IN:.*]]: f32, %[[OUT:.*]]: f32): +// CHECK-NEXT: linalg.yield %[[IN]] : f32 +// CHECK-NEXT: } -> tensor<2x4x4x2xf32> +// CHECK-NEXT: %[[S2:.*]] = tensor.empty() : tensor<1x1x6x6x5x2xf32> +// CHECK-NEXT: %[[S3:.*]] = linalg.winograd_filter_transform m(4) r(3) ins(%[[ARG1]] : tensor<2x3x3x5xf32>) outs(%[[S2]] : tensor<1x1x6x6x5x2xf32>) -> tensor<1x1x6x6x5x2xf32> +// CHECK-NEXT: %[[S4:.*]] = tensor.empty() : tensor<1x1x6x6x2x5xf32> +// CHECK-NEXT: %[[S5:.*]] = linalg.winograd_input_transform m(4) r(3) ins(%[[ARG0]] : tensor<2x6x6x5xf32>) outs(%[[S4]] : tensor<1x1x6x6x2x5xf32>) -> tensor<1x1x6x6x2x5xf32> +// CHECK-NEXT: %[[COLLAPSED:.*]] = tensor.collapse_shape %[[S3]] {{\[}}[0, 1, 2, 3], [4], [5]] : tensor<1x1x6x6x5x2xf32> into tensor<36x5x2xf32> +// CHECK-NEXT: %[[COLLAPSED_0:.*]] = tensor.collapse_shape %[[S5]] {{\[}}[0, 1, 2, 3], [4], [5]] : tensor<1x1x6x6x2x5xf32> into tensor<36x2x5xf32> +// CHECK-NEXT: %[[S6:.*]] = tensor.empty() : tensor<36x2x2xf32> +// CHECK-NEXT: %[[S7:.*]] = linalg.batch_matmul ins(%[[COLLAPSED_0]], 
%[[COLLAPSED]] : tensor<36x2x5xf32>, tensor<36x5x2xf32>) outs(%[[S6]] : tensor<36x2x2xf32>) -> tensor<36x2x2xf32> +// CHECK-NEXT: %[[EXPANDED:.*]] = tensor.expand_shape %[[S7]] {{\[}}[0, 1, 2, 3], [4], [5]] output_shape [1, 1, 6, 6, 2, 2] : tensor<36x2x2xf32> into tensor<1x1x6x6x2x2xf32> +// CHECK-NEXT: %[[S8:.*]] = linalg.winograd_output_transform m(4) r(3) ins(%[[EXPANDED]] : tensor<1x1x6x6x2x2xf32>) outs(%[[S1]] : tensor<2x4x4x2xf32>) -> tensor<2x4x4x2xf32> +// CHECK-NEXT: return %[[S8]] : tensor<2x4x4x2xf32> +// CHECK-NEXT: } + +// ----- + +func.func @conv2d_2x2_5x5(%arg0: tensor<2x6x6x5xf32>, %arg1: tensor<2x5x5x5xf32>, %arg2: tensor<1xf32>) -> tensor<2x2x2x2xf32> { + %0 = tensor.empty() : tensor<2x2x2x2xf32> + %1 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2, d3) -> (0)>, affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} ins(%arg2 : tensor<1xf32>) outs(%0 : tensor<2x2x2x2xf32>) { + ^bb0(%in: f32, %out: f32): + linalg.yield %in : f32 + } -> tensor<2x2x2x2xf32> + %2 = linalg.conv_2d_nhwc_fhwc {dilations = dense<1> : tensor<2xi64>, strides = dense<1> : tensor<2xi64>} ins(%arg0, %arg1 : tensor<2x6x6x5xf32>, tensor<2x5x5x5xf32>) outs(%1 : tensor<2x2x2x2xf32>) -> tensor<2x2x2x2xf32> + return %2 : tensor<2x2x2x2xf32> +} + +// CHECK: #[[$MAP0:.+]] = affine_map<(d0, d1, d2, d3) -> (0)> +// CHECK: #[[$MAP1:.+]] = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)> +// CHECK-LABEL: func.func @conv2d_2x2_5x5 +// CHECK-SAME: (%[[ARG0:.*]]: tensor<2x6x6x5xf32>, %[[ARG1:.*]]: tensor<2x5x5x5xf32>, %[[ARG2:.*]]: tensor<1xf32>) -> tensor<2x2x2x2xf32> { +// CHECK: %[[S0:.*]] = tensor.empty() : tensor<2x2x2x2xf32> +// CHECK-NEXT: %[[S1:.*]] = linalg.generic {indexing_maps = [#[[$MAP0]], #[[$MAP1]]], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} ins(%[[ARG2]] : tensor<1xf32>) outs(%[[S0]] : tensor<2x2x2x2xf32>) { +// CHECK-NEXT: ^bb0(%[[IN:.*]]: f32, %[[OUT:.*]]: f32): +// 
CHECK-NEXT: linalg.yield %[[IN]] : f32 +// CHECK-NEXT: } -> tensor<2x2x2x2xf32> +// CHECK-NEXT: %[[S2:.*]] = tensor.empty() : tensor<1x1x6x6x5x2xf32> +// CHECK-NEXT: %[[S3:.*]] = linalg.winograd_filter_transform m(2) r(5) ins(%[[ARG1]] : tensor<2x5x5x5xf32>) outs(%[[S2]] : tensor<1x1x6x6x5x2xf32>) -> tensor<1x1x6x6x5x2xf32> +// CHECK-NEXT: %[[S4:.*]] = tensor.empty() : tensor<1x1x6x6x2x5xf32> +// CHECK-NEXT: %[[S5:.*]] = linalg.winograd_input_transform m(2) r(5) ins(%[[ARG0]] : tensor<2x6x6x5xf32>) outs(%[[S4]] : tensor<1x1x6x6x2x5xf32>) -> tensor<1x1x6x6x2x5xf32> +// CHECK-NEXT: %[[COLLAPSED:.*]] = tensor.collapse_shape %[[S3]] {{\[}}[0, 1, 2, 3], [4], [5]] : tensor<1x1x6x6x5x2xf32> into tensor<36x5x2xf32> +// CHECK-NEXT: %[[COLLAPSED_0:.*]] = tensor.collapse_shape %[[S5]] {{\[}}[0, 1, 2, 3], [4], [5]] : tensor<1x1x6x6x2x5xf32> into tensor<36x2x5xf32> +// CHECK-NEXT: %[[S6:.*]] = tensor.empty() : tensor<36x2x2xf32> +// CHECK-NEXT: %[[S7:.*]] = linalg.batch_matmul ins(%[[COLLAPSED_0]], %[[COLLAPSED]] : tensor<36x2x5xf32>, tensor<36x5x2xf32>) outs(%[[S6]] : tensor<36x2x2xf32>) -> tensor<36x2x2xf32> +// CHECK-NEXT: %[[EXPANDED:.*]] = tensor.expand_shape %[[S7]] {{\[}}[0, 1, 2, 3], [4], [5]] output_shape [1, 1, 6, 6, 2, 2] : tensor<36x2x2xf32> into tensor<1x1x6x6x2x2xf32> +// CHECK-NEXT: %[[S8:.*]] = linalg.winograd_output_transform m(2) r(5) ins(%[[EXPANDED]] : tensor<1x1x6x6x2x2xf32>) outs(%[[S1]] : tensor<2x2x2x2xf32>) -> tensor<2x2x2x2xf32> +// CHECK-NEXT: return %[[S8]] : tensor<2x2x2x2xf32> +// CHECK-NEXT: } + +// ----- + +func.func @conv2d_1x4_1x3(%arg0: tensor<2x1x6x5xf32>, %arg1: tensor<2x1x3x5xf32>, %arg2: tensor<1xf32>) -> tensor<2x1x4x2xf32> { + %0 = tensor.empty() : tensor<2x1x4x2xf32> + %1 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2, d3) -> (0)>, affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} ins(%arg2 : tensor<1xf32>) outs(%0 : tensor<2x1x4x2xf32>) { + ^bb0(%in: f32, 
%out: f32): + linalg.yield %in : f32 + } -> tensor<2x1x4x2xf32> + %2 = linalg.conv_2d_nhwc_fhwc {dilations = dense<1> : tensor<2xi64>, strides = dense<1> : tensor<2xi64>} ins(%arg0, %arg1 : tensor<2x1x6x5xf32>, tensor<2x1x3x5xf32>) outs(%1 : tensor<2x1x4x2xf32>) -> tensor<2x1x4x2xf32> + return %2 : tensor<2x1x4x2xf32> +} + +// CHECK: #[[$MAP0:.+]] = affine_map<(d0, d1, d2, d3) -> (0)> +// CHECK: #[[$MAP1:.+]] = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)> +// CHECK-LABEL: func.func @conv2d_1x4_1x3 +// CHECK-SAME: (%[[ARG0:.*]]: tensor<2x1x6x5xf32>, %[[ARG1:.*]]: tensor<2x1x3x5xf32>, %[[ARG2:.*]]: tensor<1xf32>) -> tensor<2x1x4x2xf32> { +// CHECK: %[[S0:.*]] = tensor.empty() : tensor<2x1x4x2xf32> +// CHECK-NEXT: %[[S1:.*]] = linalg.generic {indexing_maps = [#[[$MAP0]], #[[$MAP1]]], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} ins(%[[ARG2]] : tensor<1xf32>) outs(%[[S0]] : tensor<2x1x4x2xf32>) { +// CHECK-NEXT: ^bb0(%[[IN:.*]]: f32, %[[OUT:.*]]: f32): +// CHECK-NEXT: linalg.yield %[[IN]] : f32 +// CHECK-NEXT: } -> tensor<2x1x4x2xf32> +// CHECK-NEXT: %[[S2:.*]] = tensor.empty() : tensor<1x1x1x6x5x2xf32> +// CHECK-NEXT: %[[S3:.*]] = linalg.winograd_filter_transform m(4) r(3) ins(%[[ARG1]] : tensor<2x1x3x5xf32>) outs(%[[S2]] : tensor<1x1x1x6x5x2xf32>) -> tensor<1x1x1x6x5x2xf32> +// CHECK-NEXT: %[[S4:.*]] = tensor.empty() : tensor<1x1x1x6x2x5xf32> +// CHECK-NEXT: %[[S5:.*]] = linalg.winograd_input_transform m(4) r(3) ins(%[[ARG0]] : tensor<2x1x6x5xf32>) outs(%[[S4]] : tensor<1x1x1x6x2x5xf32>) -> tensor<1x1x1x6x2x5xf32> +// CHECK-NEXT: %[[COLLAPSED:.*]] = tensor.collapse_shape %[[S3]] {{\[}}[0, 1, 2, 3], [4], [5]] : tensor<1x1x1x6x5x2xf32> into tensor<6x5x2xf32> +// CHECK-NEXT: %[[COLLAPSED_0:.*]] = tensor.collapse_shape %[[S5]] {{\[}}[0, 1, 2, 3], [4], [5]] : tensor<1x1x1x6x2x5xf32> into tensor<6x2x5xf32> +// CHECK-NEXT: %[[S6:.*]] = tensor.empty() : tensor<6x2x2xf32> +// CHECK-NEXT: %[[S7:.*]] = linalg.batch_matmul ins(%[[COLLAPSED_0]], 
%[[COLLAPSED]] : tensor<6x2x5xf32>, tensor<6x5x2xf32>) outs(%[[S6]] : tensor<6x2x2xf32>) -> tensor<6x2x2xf32> +// CHECK-NEXT: %[[EXPANDED:.*]] = tensor.expand_shape %[[S7]] {{\[}}[0, 1, 2, 3], [4], [5]] output_shape [1, 1, 1, 6, 2, 2] : tensor<6x2x2xf32> into tensor<1x1x1x6x2x2xf32> +// CHECK-NEXT: %[[S8:.*]] = linalg.winograd_output_transform m(4) r(3) ins(%[[EXPANDED]] : tensor<1x1x1x6x2x2xf32>) outs(%[[S1]] : tensor<2x1x4x2xf32>) -> tensor<2x1x4x2xf32> +// CHECK-NEXT: return %[[S8]] : tensor<2x1x4x2xf32> +// CHECK-NEXT: } + +// ----- + +func.func @conv2d_4x1_3x1(%arg0: tensor<2x6x1x5xf32>, %arg1: tensor<2x3x1x5xf32>, %arg2: tensor<1xf32>) -> tensor<2x4x1x2xf32> { + %0 = tensor.empty() : tensor<2x4x1x2xf32> + %1 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2, d3) -> (0)>, affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} ins(%arg2 : tensor<1xf32>) outs(%0 : tensor<2x4x1x2xf32>) { + ^bb0(%in: f32, %out: f32): + linalg.yield %in : f32 + } -> tensor<2x4x1x2xf32> + %2 = linalg.conv_2d_nhwc_fhwc {dilations = dense<1> : tensor<2xi64>, strides = dense<1> : tensor<2xi64>} ins(%arg0, %arg1 : tensor<2x6x1x5xf32>, tensor<2x3x1x5xf32>) outs(%1 : tensor<2x4x1x2xf32>) -> tensor<2x4x1x2xf32> + return %2 : tensor<2x4x1x2xf32> +} + +// CHECK: #[[$MAP0:.+]] = affine_map<(d0, d1, d2, d3) -> (0)> +// CHECK: #[[$MAP1:.+]] = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)> +// CHECK-LABEL: func.func @conv2d_4x1_3x1 +// CHECK-SAME: (%[[ARG0:.*]]: tensor<2x6x1x5xf32>, %[[ARG1:.*]]: tensor<2x3x1x5xf32>, %[[ARG2:.*]]: tensor<1xf32>) -> tensor<2x4x1x2xf32> { +// CHECK: %[[S0:.*]] = tensor.empty() : tensor<2x4x1x2xf32> +// CHECK-NEXT: %[[S1:.*]] = linalg.generic {indexing_maps = [#[[$MAP0]], #[[$MAP1]]], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} ins(%[[ARG2]] : tensor<1xf32>) outs(%[[S0]] : tensor<2x4x1x2xf32>) { +// CHECK-NEXT: ^bb0(%[[IN:.*]]: f32, %[[OUT:.*]]: f32): +// CHECK-NEXT: 
linalg.yield %[[IN]] : f32 +// CHECK-NEXT: } -> tensor<2x4x1x2xf32> +// CHECK-NEXT: %[[S2:.*]] = tensor.empty() : tensor<1x1x6x1x5x2xf32> +// CHECK-NEXT: %[[S3:.*]] = linalg.winograd_filter_transform m(4) r(3) ins(%[[ARG1]] : tensor<2x3x1x5xf32>) outs(%[[S2]] : tensor<1x1x6x1x5x2xf32>) -> tensor<1x1x6x1x5x2xf32> +// CHECK-NEXT: %[[S4:.*]] = tensor.empty() : tensor<1x1x6x1x2x5xf32> +// CHECK-NEXT: %[[S5:.*]] = linalg.winograd_input_transform m(4) r(3) ins(%[[ARG0]] : tensor<2x6x1x5xf32>) outs(%[[S4]] : tensor<1x1x6x1x2x5xf32>) -> tensor<1x1x6x1x2x5xf32> +// CHECK-NEXT: %[[COLLAPSED:.*]] = tensor.collapse_shape %[[S3]] {{\[}}[0, 1, 2, 3], [4], [5]] : tensor<1x1x6x1x5x2xf32> into tensor<6x5x2xf32> +// CHECK-NEXT: %[[COLLAPSED_0:.*]] = tensor.collapse_shape %[[S5]] {{\[}}[0, 1, 2, 3], [4], [5]] : tensor<1x1x6x1x2x5xf32> into tensor<6x2x5xf32> +// CHECK-NEXT: %[[S6:.*]] = tensor.empty() : tensor<6x2x2xf32> +// CHECK-NEXT: %[[S7:.*]] = linalg.batch_matmul ins(%[[COLLAPSED_0]], %[[COLLAPSED]] : tensor<6x2x5xf32>, tensor<6x5x2xf32>) outs(%[[S6]] : tensor<6x2x2xf32>) -> tensor<6x2x2xf32> +// CHECK-NEXT: %[[EXPANDED:.*]] = tensor.expand_shape %[[S7]] {{\[}}[0, 1, 2, 3], [4], [5]] output_shape [1, 1, 6, 1, 2, 2] : tensor<6x2x2xf32> into tensor<1x1x6x1x2x2xf32> +// CHECK-NEXT: %[[S8:.*]] = linalg.winograd_output_transform m(4) r(3) ins(%[[EXPANDED]] : tensor<1x1x6x1x2x2xf32>) outs(%[[S1]] : tensor<2x4x1x2xf32>) -> tensor<2x4x1x2xf32> +// CHECK-NEXT: return %[[S8]] : tensor<2x4x1x2xf32> +// CHECK-NEXT: } + +// ----- + +func.func @conv2d_aligned(%arg0: tensor<2x10x10x5xf32>, %arg1: tensor<2x3x3x5xf32>, %arg2: tensor<1xf32>) -> tensor<2x8x8x2xf32> { + %0 = tensor.empty() : tensor<2x8x8x2xf32> + %1 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2, d3) -> (0)>, affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} ins(%arg2 : tensor<1xf32>) outs(%0 : tensor<2x8x8x2xf32>) { + ^bb0(%in: f32, %out: f32): + 
linalg.yield %in : f32 + } -> tensor<2x8x8x2xf32> + %2 = linalg.conv_2d_nhwc_fhwc {dilations = dense<1> : tensor<2xi64>, strides = dense<1> : tensor<2xi64>} ins(%arg0, %arg1 : tensor<2x10x10x5xf32>, tensor<2x3x3x5xf32>) outs(%1 : tensor<2x8x8x2xf32>) -> tensor<2x8x8x2xf32> + return %2 : tensor<2x8x8x2xf32> +} + +// CHECK: #[[$MAP0:.+]] = affine_map<(d0, d1, d2, d3) -> (0)> +// CHECK: #[[$MAP1:.+]] = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)> +// CHECK-LABEL: func.func @conv2d_aligned +// CHECK-SAME: (%[[ARG0:.*]]: tensor<2x10x10x5xf32>, %[[ARG1:.*]]: tensor<2x3x3x5xf32>, %[[ARG2:.*]]: tensor<1xf32>) -> tensor<2x8x8x2xf32> { +// CHECK: %[[S0:.*]] = tensor.empty() : tensor<2x8x8x2xf32> +// CHECK-NEXT: %[[S1:.*]] = linalg.generic {indexing_maps = [#[[$MAP0]], #[[$MAP1]]], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} ins(%[[ARG2]] : tensor<1xf32>) outs(%[[S0]] : tensor<2x8x8x2xf32>) { +// CHECK-NEXT: ^bb0(%[[IN:.*]]: f32, %[[OUT:.*]]: f32): +// CHECK-NEXT: linalg.yield %[[IN]] : f32 +// CHECK-NEXT: } -> tensor<2x8x8x2xf32> +// CHECK-NEXT: %[[S2:.*]] = tensor.empty() : tensor<2x2x6x6x5x2xf32> +// CHECK-NEXT: %[[S3:.*]] = linalg.winograd_filter_transform m(4) r(3) ins(%[[ARG1]] : tensor<2x3x3x5xf32>) outs(%[[S2]] : tensor<2x2x6x6x5x2xf32>) -> tensor<2x2x6x6x5x2xf32> +// CHECK-NEXT: %[[S4:.*]] = tensor.empty() : tensor<2x2x6x6x2x5xf32> +// CHECK-NEXT: %[[S5:.*]] = linalg.winograd_input_transform m(4) r(3) ins(%[[ARG0]] : tensor<2x10x10x5xf32>) outs(%[[S4]] : tensor<2x2x6x6x2x5xf32>) -> tensor<2x2x6x6x2x5xf32> +// CHECK-NEXT: %[[COLLAPSED:.*]] = tensor.collapse_shape %[[S3]] {{\[}}[0, 1, 2, 3], [4], [5]] : tensor<2x2x6x6x5x2xf32> into tensor<144x5x2xf32> +// CHECK-NEXT: %[[COLLAPSED_0:.*]] = tensor.collapse_shape %[[S5]] {{\[}}[0, 1, 2, 3], [4], [5]] : tensor<2x2x6x6x2x5xf32> into tensor<144x2x5xf32> +// CHECK-NEXT: %[[S6:.*]] = tensor.empty() : tensor<144x2x2xf32> +// CHECK-NEXT: %[[S7:.*]] = linalg.batch_matmul ins(%[[COLLAPSED_0]], 
%[[COLLAPSED]] : tensor<144x2x5xf32>, tensor<144x5x2xf32>) outs(%[[S6]] : tensor<144x2x2xf32>) -> tensor<144x2x2xf32> +// CHECK-NEXT: %[[EXPANDED:.*]] = tensor.expand_shape %[[S7]] {{\[}}[0, 1, 2, 3], [4], [5]] output_shape [2, 2, 6, 6, 2, 2] : tensor<144x2x2xf32> into tensor<2x2x6x6x2x2xf32> +// CHECK-NEXT: %[[S8:.*]] = linalg.winograd_output_transform m(4) r(3) ins(%[[EXPANDED]] : tensor<2x2x6x6x2x2xf32>) outs(%[[S1]] : tensor<2x8x8x2xf32>) -> tensor<2x8x8x2xf32> +// CHECK-NEXT: return %[[S8]] : tensor<2x8x8x2xf32> +// CHECK-NEXT: } + +// ----- + +func.func @conv2d_unaligned(%arg0: tensor<2x11x11x5xf32>, %arg1: tensor<2x3x3x5xf32>, %arg2: tensor<1xf32>) -> tensor<2x9x9x2xf32> { + %0 = tensor.empty() : tensor<2x9x9x2xf32> + %1 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2, d3) -> (0)>, affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} ins(%arg2 : tensor<1xf32>) outs(%0 : tensor<2x9x9x2xf32>) { + ^bb0(%in: f32, %out: f32): + linalg.yield %in : f32 + } -> tensor<2x9x9x2xf32> + %2 = linalg.conv_2d_nhwc_fhwc {dilations = dense<1> : tensor<2xi64>, strides = dense<1> : tensor<2xi64>} ins(%arg0, %arg1 : tensor<2x11x11x5xf32>, tensor<2x3x3x5xf32>) outs(%1 : tensor<2x9x9x2xf32>) -> tensor<2x9x9x2xf32> + return %2 : tensor<2x9x9x2xf32> +} + +// CHECK: #[[$MAP0:.+]] = affine_map<(d0, d1, d2, d3) -> (0)> +// CHECK: #[[$MAP1:.+]] = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)> +// CHECK-LABEL: func.func @conv2d_unaligned +// CHECK-SAME: (%[[ARG0:.*]]: tensor<2x11x11x5xf32>, %[[ARG1:.*]]: tensor<2x3x3x5xf32>, %[[ARG2:.*]]: tensor<1xf32>) -> tensor<2x9x9x2xf32> { +// CHECK: %[[S0:.*]] = tensor.empty() : tensor<2x9x9x2xf32> +// CHECK-NEXT: %[[S1:.*]] = linalg.generic {indexing_maps = [#[[$MAP0]], #[[$MAP1]]], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} ins(%[[ARG2]] : tensor<1xf32>) outs(%[[S0]] : tensor<2x9x9x2xf32>) { +// CHECK-NEXT: ^bb0(%[[IN:.*]]: f32, %[[OUT:.*]]: f32): 
+// CHECK-NEXT: linalg.yield %[[IN]] : f32 +// CHECK-NEXT: } -> tensor<2x9x9x2xf32> +// CHECK-NEXT: %[[S2:.*]] = tensor.empty() : tensor<3x3x6x6x5x2xf32> +// CHECK-NEXT: %[[S3:.*]] = linalg.winograd_filter_transform m(4) r(3) ins(%[[ARG1]] : tensor<2x3x3x5xf32>) outs(%[[S2]] : tensor<3x3x6x6x5x2xf32>) -> tensor<3x3x6x6x5x2xf32> +// CHECK-NEXT: %[[INPUT_BUF:.*]] = tensor.empty() : tensor<2x14x14x5xf32> +// CHECK-NEXT: %[[INSERTED_SLICE:.*]] = tensor.insert_slice %[[ARG0]] into %[[INPUT_BUF]][0, 0, 0, 0] [2, 11, 11, 5] [1, 1, 1, 1] : tensor<2x11x11x5xf32> into tensor<2x14x14x5xf32> +// CHECK-NEXT: %[[S4:.*]] = tensor.empty() : tensor<3x3x6x6x2x5xf32> +// CHECK-NEXT: %[[S5:.*]] = linalg.winograd_input_transform m(4) r(3) ins(%[[INSERTED_SLICE]] : tensor<2x14x14x5xf32>) outs(%[[S4]] : tensor<3x3x6x6x2x5xf32>) -> tensor<3x3x6x6x2x5xf32> +// CHECK-NEXT: %[[COLLAPSED:.*]] = tensor.collapse_shape %[[S3]] {{\[}}[0, 1, 2, 3], [4], [5]] : tensor<3x3x6x6x5x2xf32> into tensor<324x5x2xf32> +// CHECK-NEXT: %[[COLLAPSED_0:.*]] = tensor.collapse_shape %[[S5]] {{\[}}[0, 1, 2, 3], [4], [5]] : tensor<3x3x6x6x2x5xf32> into tensor<324x2x5xf32> +// CHECK-NEXT: %[[S6:.*]] = tensor.empty() : tensor<324x2x2xf32> +// CHECK-NEXT: %[[S7:.*]] = linalg.batch_matmul ins(%[[COLLAPSED_0]], %[[COLLAPSED]] : tensor<324x2x5xf32>, tensor<324x5x2xf32>) outs(%[[S6]] : tensor<324x2x2xf32>) -> tensor<324x2x2xf32> +// CHECK-NEXT: %[[EXPANDED:.*]] = tensor.expand_shape %[[S7]] {{\[}}[0, 1, 2, 3], [4], [5]] output_shape [3, 3, 6, 6, 2, 2] : tensor<324x2x2xf32> into tensor<3x3x6x6x2x2xf32> +// CHECK-NEXT: %[[OUTPUT_BUF:.*]] = tensor.empty() : tensor<2x12x12x2xf32> +// CHECK-NEXT: %[[INSERTED_SLICE_2:.*]] = tensor.insert_slice %[[S1]] into %[[OUTPUT_BUF]][0, 0, 0, 0] [2, 9, 9, 2] [1, 1, 1, 1] : tensor<2x9x9x2xf32> into tensor<2x12x12x2xf32> +// CHECK-NEXT: %[[S8:.*]] = linalg.winograd_output_transform m(4) r(3) ins(%[[EXPANDED]] : tensor<3x3x6x6x2x2xf32>) outs(%[[INSERTED_SLICE_2]] : tensor<2x12x12x2xf32>) -> 
tensor<2x12x12x2xf32> +// CHECK-NEXT: %[[EXTRACTED_SLICE:.*]] = tensor.extract_slice %[[S8]][0, 0, 0, 0] [2, 9, 9, 2] [1, 1, 1, 1] : tensor<2x12x12x2xf32> to tensor<2x9x9x2xf32> +// CHECK-NEXT: return %[[EXTRACTED_SLICE]] : tensor<2x9x9x2xf32> +// CHECK-NEXT: } + +// ----- + +func.func @conv2d_unsupported_1(%arg0: tensor<2x6x5x5xf32>, %arg1: tensor<2x3x2x5xf32>, %arg2: tensor<1xf32>) -> tensor<2x4x4x2xf32> { + %0 = tensor.empty() : tensor<2x4x4x2xf32> + %1 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2, d3) -> (0)>, affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} ins(%arg2 : tensor<1xf32>) outs(%0 : tensor<2x4x4x2xf32>) { + ^bb0(%in: f32, %out: f32): + linalg.yield %in : f32 + } -> tensor<2x4x4x2xf32> + %2 = linalg.conv_2d_nhwc_fhwc {dilations = dense<1> : tensor<2xi64>, strides = dense<1> : tensor<2xi64>} ins(%arg0, %arg1 : tensor<2x6x5x5xf32>, tensor<2x3x2x5xf32>) outs(%1 : tensor<2x4x4x2xf32>) -> tensor<2x4x4x2xf32> + return %2 : tensor<2x4x4x2xf32> +} + +// CHECK-LABEL: conv2d_unsupported_1 +// CHECK: linalg.conv_2d_nhwc_fhwc + +// ----- + +func.func @conv2d_unsupported_2(%arg0: tensor<2x7x7x5xf32>, %arg1: tensor<2x4x4x5xf32>, %arg2: tensor<1xf32>) -> tensor<2x4x4x2xf32> { + %0 = tensor.empty() : tensor<2x4x4x2xf32> + %1 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2, d3) -> (0)>, affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} ins(%arg2 : tensor<1xf32>) outs(%0 : tensor<2x4x4x2xf32>) { + ^bb0(%in: f32, %out: f32): + linalg.yield %in : f32 + } -> tensor<2x4x4x2xf32> + %2 = linalg.conv_2d_nhwc_fhwc {dilations = dense<1> : tensor<2xi64>, strides = dense<1> : tensor<2xi64>} ins(%arg0, %arg1 : tensor<2x7x7x5xf32>, tensor<2x4x4x5xf32>) outs(%1 : tensor<2x4x4x2xf32>) -> tensor<2x4x4x2xf32> + return %2 : tensor<2x4x4x2xf32> +} + +// CHECK-LABEL: conv2d_unsupported_2 +// CHECK: linalg.conv_2d_nhwc_fhwc + +// 
----- + +func.func @conv2d_unsupported_3(%arg0: tensor<?x?x?x?xf32>, %arg1: tensor<2x3x3x5xf32>, %arg2: tensor<?x?x?x?xf32>) -> tensor<?x?x?x?xf32> { + %0 = linalg.conv_2d_nhwc_fhwc {dilations = dense<1> : tensor<2xi64>, strides = dense<1> : tensor<2xi64>} ins(%arg0, %arg1 : tensor<?x?x?x?xf32>, tensor<2x3x3x5xf32>) outs(%arg2 : tensor<?x?x?x?xf32>) -> tensor<?x?x?x?xf32> + return %0 : tensor<?x?x?x?xf32> +} + +// CHECK-LABEL: conv2d_unsupported_3 +// CHECK: linalg.conv_2d_nhwc_fhwc diff --git a/mlir/test/lib/Dialect/Linalg/TestLinalgTransforms.cpp b/mlir/test/lib/Dialect/Linalg/TestLinalgTransforms.cpp index 4892fa2f99a7c..12cb46a5968f1 100644 --- a/mlir/test/lib/Dialect/Linalg/TestLinalgTransforms.cpp +++ b/mlir/test/lib/Dialect/Linalg/TestLinalgTransforms.cpp @@ -123,6 +123,10 @@ struct TestLinalgTransforms *this, "test-erase-unnecessary-inputs", llvm::cl::desc("Test patterns to erase unnecessary inputs"), llvm::cl::init(false)}; + Option<bool> testWinogradConv2D{ + *this, "test-winograd-conv2d", + llvm::cl::desc("Test transform conv2d by Winograd conv2d algorithm"), + llvm::cl::init(false)}; }; } // namespace @@ -207,6 +211,13 @@ static void applyEraseUnnecessaryInputs(func::FuncOp funcOp) { (void)applyPatternsAndFoldGreedily(funcOp, std::move(patterns)); } +static void applyWinogradConv2D(func::FuncOp funcOp) { + RewritePatternSet patterns(funcOp.getContext()); + populateWinogradConv2DPatterns(patterns, /*m=*/4, /*r=*/3); + populateWinogradConv2DPatterns(patterns, /*m=*/2, /*r=*/5); + (void)applyPatternsAndFoldGreedily(funcOp, std::move(patterns)); +} + /// Apply transformations specified as patterns. 
void TestLinalgTransforms::runOnOperation() { if (testPatterns) @@ -231,6 +242,8 @@ void TestLinalgTransforms::runOnOperation() { return applyEraseUnusedOperandsAndResultsPatterns(getOperation()); if (testEraseUnnecessaryInputs) return applyEraseUnnecessaryInputs(getOperation()); + if (testWinogradConv2D) + return applyWinogradConv2D(getOperation()); } namespace mlir { >From dc00c795540303c9a2a683b97dc9348a008763fb Mon Sep 17 00:00:00 2001 From: Hsiangkai Wang <hsiangkai.w...@arm.com> Date: Mon, 17 Jun 2024 11:49:08 +0100 Subject: [PATCH 2/2] [mlir][linalg] Add transform operator for Winograd Conv2D algorithm Add a transform operator structured.winograd_conv2d to convert linalg.conv_2d_nhwc_fhwc to Linalg winograd operators. --- .../Linalg/TransformOps/LinalgTransformOps.td | 51 +++++++++++ .../Dialect/Linalg/Transforms/Transforms.h | 7 ++ .../TransformOps/LinalgTransformOps.cpp | 25 ++++++ .../Linalg/Transforms/WinogradConv2D.cpp | 6 ++ .../Linalg/transform-winograd-conv2d.mlir | 88 +++++++++++++++++++ 5 files changed, 177 insertions(+) create mode 100644 mlir/test/Dialect/Linalg/transform-winograd-conv2d.mlir diff --git a/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td b/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td index 93e2c2db729da..68d0f713caad4 100644 --- a/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td +++ b/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td @@ -2587,4 +2587,55 @@ def MapCopyToThreadsOp : }]; } +//===----------------------------------------------------------------------===// +// Winograd Conv2D +//===----------------------------------------------------------------------===// + +def WinogradConv2DOp : Op<Transform_Dialect, + "structured.winograd_conv2d", + [FunctionalStyleTransformOpTrait, MemoryEffectsOpInterface, + TransformOpInterface, TransformEachOpTrait, + ReportTrackingListenerFailuresOpTrait]> { + let description = [{ + Winograd Conv2D algorithm will 
convert linalg Conv2D operator into batched + matrix multiply. Before the matrix multiply, it will convert filter and + input into a format suitable for batched matrix multiply. After the matrix + multiply, it will convert output to the final result tensor. + + The algorithm F(m x m, r x r) is + + Y = A^T x [(G x g x G^T) @ (B^T x d x B)] x A + + The size of output Y is m x m. The size of filter g is r x r. The size of + input d is (m + r - 1) x (m + r - 1). A^T, A, G^T, G, B^T, and B are + transformation matrices. + + #### Return modes: + + This operation fails if `target` is unsupported. Otherwise, the operation + succeeds and returns a handle of the sequence that replaces the original + convolution. + }]; + + let arguments = (ins TransformHandleTypeInterface:$target, + I64Attr:$m, + I64Attr:$r); + let results = (outs TransformHandleTypeInterface:$transformed); + + let assemblyFormat = + "$target attr-dict `:` functional-type($target, results)"; + + let builders = [ + OpBuilder<(ins "Value":$target)> + ]; + + let extraClassDeclaration = [{ + ::mlir::DiagnosedSilenceableFailure applyToOne( + ::mlir::transform::TransformRewriter &rewriter, + ::mlir::linalg::LinalgOp target, + ::mlir::transform::ApplyToEachResultList &results, + ::mlir::transform::TransformState &state); + }]; +} + #endif // LINALG_TRANSFORM_OPS diff --git a/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h b/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h index 835aeaf2ffed3..da107b66257a5 100644 --- a/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h +++ b/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h @@ -1312,6 +1312,13 @@ FailureOr<Operation *> transposeBatchMatmul(RewriterBase &rewriter, linalg::BatchMatmulOp op, bool transposeLHS = true); +/// Convert linalg.conv_2d_nhwc_fhwc to Winograd Conv2D algorithm +/// F(m x m, r x r). m is the dimension size of output and r is the dimension +/// size of filter. 
+FailureOr<Operation *> winogradConv2D(RewriterBase &rewriter, + linalg::Conv2DNhwcFhwcOp op, int64_t m, + int64_t r); + //===----------------------------------------------------------------------===// // Rewrite patterns wrapping transformations. // TODO: every single such pattern should be a close to noop wrapper around a diff --git a/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp b/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp index bc02788f9c441..d051b29e1f06f 100644 --- a/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp +++ b/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp @@ -3480,6 +3480,31 @@ DiagnosedSilenceableFailure transform::MapCopyToThreadsOp::applyToOne( return DiagnosedSilenceableFailure::success(); } +//===----------------------------------------------------------------------===// +// WinogradConv2DOp +//===----------------------------------------------------------------------===// + +DiagnosedSilenceableFailure transform::WinogradConv2DOp::applyToOne( + transform::TransformRewriter &rewriter, linalg::LinalgOp target, + transform::ApplyToEachResultList &results, + transform::TransformState &state) { + rewriter.setInsertionPoint(target); + auto maybeTransformed = + TypeSwitch<Operation *, FailureOr<Operation *>>(target) + .Case([&](linalg::Conv2DNhwcFhwcOp op) { + return winogradConv2D(rewriter, op, getM(), getR()); + }) + .Default([&](Operation *op) { + return rewriter.notifyMatchFailure(op, "not supported"); + }); + + if (failed(maybeTransformed)) + return emitDefaultSilenceableFailure(target); + + results.push_back(*maybeTransformed); + return DiagnosedSilenceableFailure::success(); +} + #include "mlir/Dialect/Linalg/TransformOps/LinalgTransformOpsEnums.cpp.inc" #define GET_OP_CLASSES diff --git a/mlir/lib/Dialect/Linalg/Transforms/WinogradConv2D.cpp b/mlir/lib/Dialect/Linalg/Transforms/WinogradConv2D.cpp index 86e834d51f2fc..d1f4be8bbf29a 100644 --- 
a/mlir/lib/Dialect/Linalg/Transforms/WinogradConv2D.cpp +++ b/mlir/lib/Dialect/Linalg/Transforms/WinogradConv2D.cpp @@ -311,6 +311,12 @@ class WinogradConv2DNhwcFhwc final } // end anonymous namespace //===----------------------------------------------------------------------===// +FailureOr<Operation *> winogradConv2D(RewriterBase &rewriter, + linalg::Conv2DNhwcFhwcOp op, int64_t m, + int64_t r) { + return winogradConv2DHelper(rewriter, op, m, r); +} + void populateWinogradConv2DPatterns(RewritePatternSet &patterns, int64_t m, int64_t r) { MLIRContext *context = patterns.getContext(); diff --git a/mlir/test/Dialect/Linalg/transform-winograd-conv2d.mlir b/mlir/test/Dialect/Linalg/transform-winograd-conv2d.mlir new file mode 100644 index 0000000000000..1e74fea5a1c31 --- /dev/null +++ b/mlir/test/Dialect/Linalg/transform-winograd-conv2d.mlir @@ -0,0 +1,88 @@ +// RUN: mlir-opt %s -transform-interpreter -canonicalize --split-input-file | FileCheck %s + +func.func @conv2d(%arg0: tensor<2x10x10x5xf32>, %arg1: tensor<2x3x3x5xf32>, %arg2: tensor<1xf32>) -> tensor<2x8x8x2xf32> { + %0 = tensor.empty() : tensor<2x8x8x2xf32> + %1 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2, d3) -> (0)>, affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} ins(%arg2 : tensor<1xf32>) outs(%0 : tensor<2x8x8x2xf32>) { + ^bb0(%in: f32, %out: f32): + linalg.yield %in : f32 + } -> tensor<2x8x8x2xf32> + %2 = linalg.conv_2d_nhwc_fhwc {dilations = dense<1> : tensor<2xi64>, strides = dense<1> : tensor<2xi64>} ins(%arg0, %arg1 : tensor<2x10x10x5xf32>, tensor<2x3x3x5xf32>) outs(%1 : tensor<2x8x8x2xf32>) -> tensor<2x8x8x2xf32> + return %2 : tensor<2x8x8x2xf32> +} + +module attributes {transform.with_named_sequence} { + transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) { + %0 = transform.structured.match ops{["linalg.conv_2d_nhwc_fhwc"]} in %arg1 : (!transform.any_op) -> !transform.any_op + 
%1 = transform.structured.winograd_conv2d %0 { m = 4, r = 3 } : (!transform.any_op) -> (!transform.any_op) + transform.yield + } +} + +// CHECK: #[[$MAP0:.+]] = affine_map<(d0, d1, d2, d3) -> (0)> +// CHECK: #[[$MAP1:.+]] = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)> +// CHECK-LABEL: func.func @conv2d +// CHECK-SAME: (%[[ARG0:.*]]: tensor<2x10x10x5xf32>, %[[ARG1:.*]]: tensor<2x3x3x5xf32>, %[[ARG2:.*]]: tensor<1xf32>) -> tensor<2x8x8x2xf32> { +// CHECK: %[[S0:.*]] = tensor.empty() : tensor<2x8x8x2xf32> +// CHECK-NEXT: %[[S1:.*]] = linalg.generic {indexing_maps = [#[[$MAP0]], #[[$MAP1]]], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} ins(%[[ARG2]] : tensor<1xf32>) outs(%[[S0]] : tensor<2x8x8x2xf32>) { +// CHECK-NEXT: ^bb0(%[[IN:.*]]: f32, %[[OUT:.*]]: f32): +// CHECK-NEXT: linalg.yield %[[IN]] : f32 +// CHECK-NEXT: } -> tensor<2x8x8x2xf32> +// CHECK-NEXT: %[[S2:.*]] = tensor.empty() : tensor<2x2x6x6x5x2xf32> +// CHECK-NEXT: %[[S3:.*]] = linalg.winograd_filter_transform m(4) r(3) ins(%[[ARG1]] : tensor<2x3x3x5xf32>) outs(%[[S2]] : tensor<2x2x6x6x5x2xf32>) -> tensor<2x2x6x6x5x2xf32> +// CHECK-NEXT: %[[S4:.*]] = tensor.empty() : tensor<2x2x6x6x2x5xf32> +// CHECK-NEXT: %[[S5:.*]] = linalg.winograd_input_transform m(4) r(3) ins(%[[ARG0]] : tensor<2x10x10x5xf32>) outs(%[[S4]] : tensor<2x2x6x6x2x5xf32>) -> tensor<2x2x6x6x2x5xf32> +// CHECK-NEXT: %[[COLLAPSED:.*]] = tensor.collapse_shape %[[S3]] {{\[}}[0, 1, 2, 3], [4], [5]] : tensor<2x2x6x6x5x2xf32> into tensor<144x5x2xf32> +// CHECK-NEXT: %[[COLLAPSED_0:.*]] = tensor.collapse_shape %[[S5]] {{\[}}[0, 1, 2, 3], [4], [5]] : tensor<2x2x6x6x2x5xf32> into tensor<144x2x5xf32> +// CHECK-NEXT: %[[S6:.*]] = tensor.empty() : tensor<144x2x2xf32> +// CHECK-NEXT: %[[S7:.*]] = linalg.batch_matmul ins(%[[COLLAPSED_0]], %[[COLLAPSED]] : tensor<144x2x5xf32>, tensor<144x5x2xf32>) outs(%[[S6]] : tensor<144x2x2xf32>) -> tensor<144x2x2xf32> +// CHECK-NEXT: %[[EXPANDED:.*]] = tensor.expand_shape %[[S7]] {{\[}}[0, 1, 2, 
3], [4], [5]] output_shape [2, 2, 6, 6, 2, 2] : tensor<144x2x2xf32> into tensor<2x2x6x6x2x2xf32> +// CHECK-NEXT: %[[S8:.*]] = linalg.winograd_output_transform m(4) r(3) ins(%[[EXPANDED]] : tensor<2x2x6x6x2x2xf32>) outs(%[[S1]] : tensor<2x8x8x2xf32>) -> tensor<2x8x8x2xf32> +// CHECK-NEXT: return %[[S8]] : tensor<2x8x8x2xf32> +// CHECK-NEXT: } + +// ----- + +func.func @conv2d_unaligned(%arg0: tensor<2x11x11x5xf32>, %arg1: tensor<2x3x3x5xf32>, %arg2: tensor<1xf32>) -> tensor<2x9x9x2xf32> { + %0 = tensor.empty() : tensor<2x9x9x2xf32> + %1 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2, d3) -> (0)>, affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} ins(%arg2 : tensor<1xf32>) outs(%0 : tensor<2x9x9x2xf32>) { + ^bb0(%in: f32, %out: f32): + linalg.yield %in : f32 + } -> tensor<2x9x9x2xf32> + %2 = linalg.conv_2d_nhwc_fhwc {dilations = dense<1> : tensor<2xi64>, strides = dense<1> : tensor<2xi64>} ins(%arg0, %arg1 : tensor<2x11x11x5xf32>, tensor<2x3x3x5xf32>) outs(%1 : tensor<2x9x9x2xf32>) -> tensor<2x9x9x2xf32> + return %2 : tensor<2x9x9x2xf32> +} + +module attributes {transform.with_named_sequence} { + transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) { + %0 = transform.structured.match ops{["linalg.conv_2d_nhwc_fhwc"]} in %arg1 : (!transform.any_op) -> !transform.any_op + %1 = transform.structured.winograd_conv2d %0 { m = 4, r = 3 } : (!transform.any_op) -> (!transform.any_op) + transform.yield + } +} + +// CHECK: #[[$MAP0:.+]] = affine_map<(d0, d1, d2, d3) -> (0)> +// CHECK: #[[$MAP1:.+]] = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)> +// CHECK-LABEL: func.func @conv2d_unaligned +// CHECK-SAME: (%[[ARG0:.*]]: tensor<2x11x11x5xf32>, %[[ARG1:.*]]: tensor<2x3x3x5xf32>, %[[ARG2:.*]]: tensor<1xf32>) -> tensor<2x9x9x2xf32> { +// CHECK: %[[S0:.*]] = tensor.empty() : tensor<2x9x9x2xf32> +// CHECK-NEXT: %[[S1:.*]] = linalg.generic {indexing_maps = [#[[$MAP0]], 
#[[$MAP1]]], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} ins(%[[ARG2]] : tensor<1xf32>) outs(%[[S0]] : tensor<2x9x9x2xf32>) { +// CHECK-NEXT: ^bb0(%[[IN:.*]]: f32, %[[OUT:.*]]: f32): +// CHECK-NEXT: linalg.yield %[[IN]] : f32 +// CHECK-NEXT: } -> tensor<2x9x9x2xf32> +// CHECK-NEXT: %[[S2:.*]] = tensor.empty() : tensor<3x3x6x6x5x2xf32> +// CHECK-NEXT: %[[S3:.*]] = linalg.winograd_filter_transform m(4) r(3) ins(%[[ARG1]] : tensor<2x3x3x5xf32>) outs(%[[S2]] : tensor<3x3x6x6x5x2xf32>) -> tensor<3x3x6x6x5x2xf32> +// CHECK-NEXT: %[[INPUT_BUF:.*]] = tensor.empty() : tensor<2x14x14x5xf32> +// CHECK-NEXT: %[[INSERTED_SLICE:.*]] = tensor.insert_slice %[[ARG0]] into %[[INPUT_BUF]][0, 0, 0, 0] [2, 11, 11, 5] [1, 1, 1, 1] : tensor<2x11x11x5xf32> into tensor<2x14x14x5xf32> +// CHECK-NEXT: %[[S4:.*]] = tensor.empty() : tensor<3x3x6x6x2x5xf32> +// CHECK-NEXT: %[[S5:.*]] = linalg.winograd_input_transform m(4) r(3) ins(%[[INSERTED_SLICE]] : tensor<2x14x14x5xf32>) outs(%[[S4]] : tensor<3x3x6x6x2x5xf32>) -> tensor<3x3x6x6x2x5xf32> +// CHECK-NEXT: %[[COLLAPSED:.*]] = tensor.collapse_shape %[[S3]] {{\[}}[0, 1, 2, 3], [4], [5]] : tensor<3x3x6x6x5x2xf32> into tensor<324x5x2xf32> +// CHECK-NEXT: %[[COLLAPSED_0:.*]] = tensor.collapse_shape %[[S5]] {{\[}}[0, 1, 2, 3], [4], [5]] : tensor<3x3x6x6x2x5xf32> into tensor<324x2x5xf32> +// CHECK-NEXT: %[[S6:.*]] = tensor.empty() : tensor<324x2x2xf32> +// CHECK-NEXT: %[[S7:.*]] = linalg.batch_matmul ins(%[[COLLAPSED_0]], %[[COLLAPSED]] : tensor<324x2x5xf32>, tensor<324x5x2xf32>) outs(%[[S6]] : tensor<324x2x2xf32>) -> tensor<324x2x2xf32> +// CHECK-NEXT: %[[EXPANDED:.*]] = tensor.expand_shape %[[S7]] {{\[}}[0, 1, 2, 3], [4], [5]] output_shape [3, 3, 6, 6, 2, 2] : tensor<324x2x2xf32> into tensor<3x3x6x6x2x2xf32> +// CHECK-NEXT: %[[OUTPUT_BUF:.*]] = tensor.empty() : tensor<2x12x12x2xf32> +// CHECK-NEXT: %[[INSERTED_SLICE_2:.*]] = tensor.insert_slice %[[S1]] into %[[OUTPUT_BUF]][0, 0, 0, 0] [2, 9, 9, 2] [1, 1, 1, 1] : 
tensor<2x9x9x2xf32> into tensor<2x12x12x2xf32> +// CHECK-NEXT: %[[S8:.*]] = linalg.winograd_output_transform m(4) r(3) ins(%[[EXPANDED]] : tensor<3x3x6x6x2x2xf32>) outs(%[[INSERTED_SLICE_2]] : tensor<2x12x12x2xf32>) -> tensor<2x12x12x2xf32> +// CHECK-NEXT: %[[EXTRACTED_SLICE:.*]] = tensor.extract_slice %[[S8]][0, 0, 0, 0] [2, 9, 9, 2] [1, 1, 1, 1] : tensor<2x12x12x2xf32> to tensor<2x9x9x2xf32> +// CHECK-NEXT: return %[[EXTRACTED_SLICE]] : tensor<2x9x9x2xf32> +// CHECK-NEXT: } _______________________________________________ llvm-branch-commits mailing list llvm-branch-commits@lists.llvm.org https://lists.llvm.org/cgi-bin/mailman/listinfo/llvm-branch-commits