llvmbot wrote:
<!--LLVM PR SUMMARY COMMENT--> @llvm/pr-subscribers-mlir Author: Hsiangkai Wang (Hsiangkai) <details> <summary>Changes</summary> --- Patch is 57.69 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/96177.diff 10 Files Affected: - (modified) mlir/include/mlir/Dialect/Linalg/IR/LinalgOps.td (+114) - (modified) mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td (+51) - (modified) mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h (+11) - (modified) mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp (+78) - (modified) mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp (+25) - (modified) mlir/lib/Dialect/Linalg/Transforms/CMakeLists.txt (+1) - (added) mlir/lib/Dialect/Linalg/Transforms/WinogradConv2D.cpp (+327) - (added) mlir/test/Dialect/Linalg/transform-winograd-conv2d.mlir (+88) - (added) mlir/test/Dialect/Linalg/winograd-conv2d.mlir (+248) - (modified) mlir/test/lib/Dialect/Linalg/TestLinalgTransforms.cpp (+13) ``````````diff diff --git a/mlir/include/mlir/Dialect/Linalg/IR/LinalgOps.td b/mlir/include/mlir/Dialect/Linalg/IR/LinalgOps.td index 64c538367267d..de1097b6ac27b 100644 --- a/mlir/include/mlir/Dialect/Linalg/IR/LinalgOps.td +++ b/mlir/include/mlir/Dialect/Linalg/IR/LinalgOps.td @@ -154,4 +154,118 @@ def Linalg_SoftmaxOp : Linalg_Op<"softmax", let hasVerifier = 1; } +def Linalg_WinogradFilterTransformOp : Linalg_Op<"winograd_filter_transform"> { + let summary = "Winograd filter transform operator"; + let description = [{ + Winograd Conv2D algorithm will convert linalg Conv2D operator into batched + matrix multiply. Before the matrix multiply, it will convert filter and + input into a format suitable for batched matrix multiply. After the matrix + multiply, it will convert output to the final result tensor. + + The algorithm F(m x m, r x r) is + + Y = A^T x [(G x g x G^T) @ (B^T x d x B)] x A + + The size of output Y is m x m. The size of filter g is r x r. The size of + input d is (m + r - 1) x (m + r - 1). A^T, A, G^T, G, B^T, and B are + transformation matrices. + + This operator is defined to represent the high level concept of filter + transformation (G x g x G^T) in the Winograd Conv2D algorithm. + }]; + + let arguments = (ins AnyRankedTensor:$filter, + AnyRankedTensor:$output, + I64Attr:$m, + I64Attr:$r + ); + + let results = (outs AnyRankedTensor:$result); + let assemblyFormat = [{ + attr-dict + `m` `(` $m `)` + `r` `(` $r `)` + `ins` `(` $filter `:` type($filter) `)` + `outs` `(` $output `:` type($output) `)` + `->` type($result) + }]; + let hasVerifier = 1; +} + +def Linalg_WinogradInputTransformOp : Linalg_Op<"winograd_input_transform"> { + let summary = "Winograd input transform operator"; + let description = [{ + Winograd Conv2D algorithm will convert linalg Conv2D operator into batched + matrix multiply. Before the matrix multiply, it will convert filter and + input into a format suitable for batched matrix multiply. After the matrix + multiply, it will convert output to the final result tensor. + + The algorithm F(m x m, r x r) is + + Y = A^T x [(G x g x G^T) @ (B^T x d x B)] x A + + The size of output Y is m x m. The size of filter g is r x r. The size of + input d is (m + r - 1) x (m + r - 1). A^T, A, G^T, G, B^T, and B are + transformation matrices. + + This operator is defined to represent the high level concept of input + transformation (B^T x d x B) in the Winograd Conv2D algorithm. + }]; + + let arguments = (ins AnyRankedTensor:$input, + AnyRankedTensor:$output, + I64Attr:$m, + I64Attr:$r + ); + + let results = (outs AnyRankedTensor:$result); + let assemblyFormat = [{ + attr-dict + `m` `(` $m `)` + `r` `(` $r `)` + `ins` `(` $input `:` type($input) `)` + `outs` `(` $output `:` type($output) `)` + `->` type($result) + }]; + let hasVerifier = 1; +} + +def Linalg_WinogradOutputTransformOp : Linalg_Op<"winograd_output_transform"> { + let summary = "Winograd output transform operator"; + let description = [{ + Winograd Conv2D algorithm will convert linalg Conv2D operator into batched + matrix multiply. Before the matrix multiply, it will convert filter and + input into a format suitable for batched matrix multiply. After the matrix + multiply, it will convert output to the final result tensor. + + The algorithm F(m x m, r x r) is + + Y = A^T x [(G x g x G^T) @ (B^T x d x B)] x A + + The size of output Y is m x m. The size of filter g is r x r. The size of + input d is (m + r - 1) x (m + r - 1). A^T, A, G^T, G, B^T, and B are + transformation matrices. + + This operator is defined to represent the high level concept of output + transformation (A^T x y x A) in the Winograd Conv2D algorithm. + }]; + + let arguments = (ins AnyRankedTensor:$value, + AnyRankedTensor:$output, + I64Attr:$m, + I64Attr:$r + ); + + let results = (outs AnyRankedTensor:$result); + let assemblyFormat = [{ + attr-dict + `m` `(` $m `)` + `r` `(` $r `)` + `ins` `(` $value `:` type($value) `)` + `outs` `(` $output `:` type($output) `)` + `->` type($result) + }]; + let hasVerifier = 1; +} + #endif // LINALG_OPS diff --git a/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td b/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td index 93e2c2db729da..68d0f713caad4 100644 --- a/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td +++ b/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td @@ -2587,4 +2587,55 @@ def MapCopyToThreadsOp : }]; } +//===----------------------------------------------------------------------===// +// Winograd Conv2D +//===----------------------------------------------------------------------===// + +def WinogradConv2DOp : Op<Transform_Dialect, + "structured.winograd_conv2d", + [FunctionalStyleTransformOpTrait, MemoryEffectsOpInterface, + TransformOpInterface, TransformEachOpTrait, + ReportTrackingListenerFailuresOpTrait]> { + let description = [{ + Winograd Conv2D algorithm will convert linalg Conv2D operator into batched + matrix multiply. Before the matrix multiply, it will convert filter and + input into a format suitable for batched matrix multiply. After the matrix + multiply, it will convert output to the final result tensor. + + The algorithm F(m x m, r x r) is + + Y = A^T x [(G x g x G^T) @ (B^T x d x B)] x A + + The size of output Y is m x m. The size of filter g is r x r. The size of + input d is (m + r - 1) x (m + r - 1). A^T, A, G^T, G, B^T, and B are + transformation matrices. + + #### Return modes: + + This operation fails if `target` is unsupported. Otherwise, the operation + succeeds and returns a handle of the sequence that replaces the original + convolution. + }]; + + let arguments = (ins TransformHandleTypeInterface:$target, + I64Attr:$m, + I64Attr:$r); + let results = (outs TransformHandleTypeInterface:$transformed); + + let assemblyFormat = + "$target attr-dict `:` functional-type($target, results)"; + + let builders = [ + OpBuilder<(ins "Value":$target)> + ]; + + let extraClassDeclaration = [{ + ::mlir::DiagnosedSilenceableFailure applyToOne( + ::mlir::transform::TransformRewriter &rewriter, + ::mlir::linalg::LinalgOp target, + ::mlir::transform::ApplyToEachResultList &results, + ::mlir::transform::TransformState &state); + }]; +} + #endif // LINALG_TRANSFORM_OPS diff --git a/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h b/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h index 05e97befdec1f..da107b66257a5 100644 --- a/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h +++ b/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h @@ -1312,6 +1312,13 @@ FailureOr<Operation *> transposeBatchMatmul(RewriterBase &rewriter, linalg::BatchMatmulOp op, bool transposeLHS = true); +/// Convert linalg.conv_2d_nhwc_fhwc to Winograd Conv2D algorithm +/// F(m x m, r x r). m is the dimension size of output and r is the dimension +/// size of filter. +FailureOr<Operation *> winogradConv2D(RewriterBase &rewriter, + linalg::Conv2DNhwcFhwcOp op, int64_t m, + int64_t r); + //===----------------------------------------------------------------------===// // Rewrite patterns wrapping transformations. // TODO: every single such pattern should be a close to noop wrapper around a @@ -1692,6 +1699,10 @@ void populateTransposeMatmulPatterns(RewritePatternSet &patterns, void populateBlockPackMatmulPatterns(RewritePatternSet &patterns, const ControlBlockPackMatmulFn &controlFn); +/// Patterns to apply Winograd Conv2D algorithm F(m x m, r x r). +void populateWinogradConv2DPatterns(RewritePatternSet &patterns, int64_t m, + int64_t r); + } // namespace linalg } // namespace mlir diff --git a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp index 57d126603ebd7..7bf2a5bca037f 100644 --- a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp +++ b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp @@ -2734,6 +2734,84 @@ FailureOr<SmallVector<Value>> SoftmaxOp::decomposeOperation(OpBuilder &b) { return SmallVector<Value>{result}; } +//===----------------------------------------------------------------------===// +// WinogradFilterTransformOp +//===----------------------------------------------------------------------===// + +LogicalResult WinogradFilterTransformOp::verify() { + auto filterType = cast<ShapedType>(getFilter().getType()); + auto outputType = cast<ShapedType>(getOutput().getType()); + auto filterElemType = filterType.getElementType(); + auto outputElemType = outputType.getElementType(); + if (filterElemType != outputElemType) { + return emitOpError() << "expected element type of input " << filterElemType + << " to match element type of output " + << outputElemType; + } + + unsigned filterRank = filterType.getRank(); + if (filterRank != 4) + return emitOpError() << "expected rank of input is 4"; + + unsigned outputRank = outputType.getRank(); + if (outputRank != 6) + return emitOpError() << "expected rank of output is 6"; + + return success(); +} + +//===----------------------------------------------------------------------===// +// WinogradInputTransformOp +//===----------------------------------------------------------------------===// + +LogicalResult WinogradInputTransformOp::verify() { + auto inputType = cast<ShapedType>(getInput().getType()); + auto outputType = cast<ShapedType>(getOutput().getType()); + auto inputElemType = inputType.getElementType(); + auto outputElemType = outputType.getElementType(); + if (inputElemType != outputElemType) { + return emitOpError() << "expected element type of input " << inputElemType + << " to match element type of output " + << outputElemType; + } + + unsigned inputRank = inputType.getRank(); + if (inputRank != 4) + return emitOpError() << "expected rank of input is 4"; + + unsigned outputRank = outputType.getRank(); + if (outputRank != 6) + return emitOpError() << "expected rank of output is 6"; + + return success(); +} + +//===----------------------------------------------------------------------===// +// WinogradOutputTransformOp +//===----------------------------------------------------------------------===// + +LogicalResult WinogradOutputTransformOp::verify() { + auto valueType = cast<ShapedType>(getValue().getType()); + auto outputType = cast<ShapedType>(getOutput().getType()); + auto valueElemType = valueType.getElementType(); + auto outputElemType = outputType.getElementType(); + if (valueElemType != outputElemType) { + return emitOpError() << "expected element type of value " << valueElemType + << " to match element type of output " + << outputElemType; + } + + unsigned valueRank = valueType.getRank(); + if (valueRank != 6) + return emitOpError() << "expected rank of input is 6"; + + unsigned outputRank = outputType.getRank(); + if (outputRank != 4) + return emitOpError() << "expected rank of output is 4"; + + return success(); +} + //===----------------------------------------------------------------------===// // LinalgDialect //===----------------------------------------------------------------------===// diff --git a/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp b/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp index bc02788f9c441..d051b29e1f06f 100644 --- a/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp +++ b/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp @@ -3480,6 +3480,31 @@ DiagnosedSilenceableFailure transform::MapCopyToThreadsOp::applyToOne( return DiagnosedSilenceableFailure::success(); } +//===----------------------------------------------------------------------===// +// WinogradConv2DOp +//===----------------------------------------------------------------------===// + +DiagnosedSilenceableFailure transform::WinogradConv2DOp::applyToOne( + transform::TransformRewriter &rewriter, linalg::LinalgOp target, + transform::ApplyToEachResultList &results, + transform::TransformState &state) { + rewriter.setInsertionPoint(target); + auto maybeTransformed = + TypeSwitch<Operation *, FailureOr<Operation *>>(target) + .Case([&](linalg::Conv2DNhwcFhwcOp op) { + return winogradConv2D(rewriter, op, getM(), getR()); + }) + .Default([&](Operation *op) { + return rewriter.notifyMatchFailure(op, "not supported"); + }); + + if (failed(maybeTransformed)) + return emitDefaultSilenceableFailure(target); + + results.push_back(*maybeTransformed); + return DiagnosedSilenceableFailure::success(); +} + #include "mlir/Dialect/Linalg/TransformOps/LinalgTransformOpsEnums.cpp.inc" #define GET_OP_CLASSES diff --git a/mlir/lib/Dialect/Linalg/Transforms/CMakeLists.txt b/mlir/lib/Dialect/Linalg/Transforms/CMakeLists.txt index 7e3dc56e0acdc..a7dcc29b5b9be 100644 --- a/mlir/lib/Dialect/Linalg/Transforms/CMakeLists.txt +++ b/mlir/lib/Dialect/Linalg/Transforms/CMakeLists.txt @@ -38,6 +38,7 @@ add_mlir_dialect_library(MLIRLinalgTransforms Transforms.cpp TransposeConv2D.cpp Vectorization.cpp + WinogradConv2D.cpp ADDITIONAL_HEADER_DIRS ${MLIR_MAIN_INCLUDE_DIR}/mlir/Dialect/Linalg diff --git a/mlir/lib/Dialect/Linalg/Transforms/WinogradConv2D.cpp b/mlir/lib/Dialect/Linalg/Transforms/WinogradConv2D.cpp new file mode 100644 index 0000000000000..d1f4be8bbf29a --- /dev/null +++ b/mlir/lib/Dialect/Linalg/Transforms/WinogradConv2D.cpp @@ -0,0 +1,327 @@ +//===- WinogradConv2D.cpp - Winograd Conv2D implementation ----------------===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// Implement Winograd Conv2D algorithm. The implementation is based on the +// paper: Fast Algorithms for Convolutional Neural Networks +// (https://arxiv.org/abs/1509.09308) +// +//===----------------------------------------------------------------------===// + +#include "mlir/Dialect/Linalg/IR/Linalg.h" +#include "mlir/Dialect/Tensor/IR/Tensor.h" +#include "mlir/Dialect/Tosa/Utils/ConversionUtils.h" +#include "mlir/Transforms/GreedyPatternRewriteDriver.h" +#include "llvm/Support/MathExtras.h" + +namespace mlir { +namespace linalg { + +namespace { + +using TransformMapKeyTy = std::pair<int, int>; + +// We use F(m, r) to define the size of minimal filtering algorithms. +// m is the output dimension and r is the filter dimension. We can get +// the input dimension, alpha, from the formula, alpha = m + r - 1. +// +// For example, when m = 2 and r = 3, we know its input size is 4. +// The Conv2D will operate on 4x4 input data with 3x3 filter and get +// 2x2 output result. +constexpr TransformMapKeyTy F_2_3{2, 3}; +constexpr TransformMapKeyTy F_4_3{4, 3}; +constexpr TransformMapKeyTy F_2_5{2, 5}; + +Value collapse2DData(RewriterBase &rewriter, Location loc, Value data) { + auto type = cast<ShapedType>(data.getType()); + auto elementType = type.getElementType(); + auto shape = type.getShape(); + auto collapseType = RankedTensorType::get( + {shape[0] * shape[1] * shape[2] * shape[3], shape[4], shape[5]}, + elementType); + SmallVector<ReassociationIndices> reassociation = {{0, 1, 2, 3}, {4}, {5}}; + return rewriter.create<tensor::CollapseShapeOp>(loc, collapseType, data, + reassociation); +} + +// This function generates linalg.batch_matmul to multiply input with filter. +// linalg.batch_matmul only supports 3-dimension data sets. We can treat +// tileH x tileW x H x W data as the 1-dimension data array. That is to convert +// [tileH, tileW, H, W, N, C] to [tileH x tileW x H x W, N, C]. In this way, we +// can convert 6-dimension input data to 3-dimension representation that is +// suitable for linalg.batch_matmul. +// +// Batched matmul will do the matrix multiply with the reduction on channel. +// +// We get +// +// %collapsed_input = tensor.collapse_shape %input +// %collapsed_filter = tensor.collapse_shape %filter +// %ret = linalg.batch_matmul %collapsed_input, %collapsed_filter +// %expanded_ret = tensor.expand_shape %ret +// +// After this function, we get return value with data layout +// (tileH, tileW, H, W, N, F). +Value matrixMultiply(RewriterBase &rewriter, Location loc, + Value transformedFilter, Value transformedInput) { + auto collapseFilter = collapse2DData(rewriter, loc, transformedFilter); + auto collapseInput = collapse2DData(rewriter, loc, transformedInput); + + // Batched matrix multiply + auto filterType = cast<ShapedType>(transformedFilter.getType()); + auto filterShape = filterType.getShape(); + auto inputType = cast<ShapedType>(transformedInput.getType()); + auto inputElemType = inputType.getElementType(); + auto inputShape = inputType.getShape(); + + auto matmulType = RankedTensorType::get( + {inputShape[0] * inputShape[1] * inputShape[2] * inputShape[3], + inputShape[4], filterShape[5]}, + inputElemType); + Value init = rewriter.create<tensor::EmptyOp>(loc, matmulType.getShape(), + inputElemType); + + auto matmulOp = rewriter.create<linalg::BatchMatmulOp>( + loc, matmulType, ValueRange({collapseInput, collapseFilter}), + ValueRange{init}); + + // Expand matmul result + SmallVector<ReassociationIndices> reassociation = {{0, 1, 2, 3}, {4}, {5}}; + auto expandType = + RankedTensorType::get({inputShape[0], inputShape[1], inputShape[2], + inputShape[3], inputShape[4], filterShape[5]}, + inputElemType); + auto expandOutput = rewriter.create<tensor::ExpandShapeOp>( + loc, expandType, matmulOp.getResult(0), reassociation); + return expandOutput; +} + +Value insertToAlignedTensor(RewriterBase &rewriter, Location loc, Value value, + RankedTensorType alignedType) { + Value alignedInput = rewriter.create<tensor::EmptyOp>( + loc, alignedType.getShape(), alignedType.getElementType()); + + auto zeroIndex = rewriter.getIndexAttr(0); + auto oneIndex = rewriter.getIndexAttr(1); + SmallVector<OpFoldResult, 4> offsets(4, zeroIndex); + SmallVector<OpFoldResult, 4> strides(4, oneIndex); + + auto valueType = cast<ShapedType>(value.getType()); + auto valueShape = valueType.getShape(); + SmallVector<OpFoldResult, 4> sizes; + sizes.emplace_back(rewriter.getIndexAttr(valueShape[0])); + sizes.emplace_back(rewriter.getIndexAttr(valueShape[1])); + sizes.emplace_back(rewriter.getIndexAttr(valueShape[2])); + sizes.emplace_back(rewriter.getIndexAttr(valueShape[3])); + + return rewriter.create<tensor::InsertSliceOp>(loc, value, alignedInput, + offsets, sizes, strides); +} + +Value extractFromAlignedTensor(RewriterBase &rewriter, Location loc, + Value value, RankedTensorType extractedType) { + auto zeroIndex = rewriter.getIndexAttr(0); + auto oneIndex = rewriter.getIndexAttr(1); + SmallVector<OpFoldResult, 4> offsets(4, zeroIndex); + SmallVector<OpFoldResult, 4> strides(4, oneIndex); + + auto extractedShape = extractedType.getShape(); + SmallVector<OpFoldResult, 4> sizes; + sizes.emplace_back(rewriter.getIndexAttr(extractedShape[0])); + sizes.emplace_back(rewriter.getIndexAttr(extractedShape[1])); + sizes.emplace_back(rewriter.getIndexAttr(extractedShape[2])); + sizes.emplace_back(rewriter.getIndexAttr(extractedShape[3])); + + return rewriter.create<tensor::ExtractSliceOp>(lo... [truncated] `````````` </details> https://github.com/llvm/llvm-project/pull/96177 _______________________________________________ llvm-branch-commits mailing list llvm-branch-commits@lists.llvm.org https://lists.llvm.org/cgi-bin/mailman/listinfo/llvm-branch-commits