[llvm-branch-commits] [mlir] [draft] Dialect Conversion without Rollback (PR #93412)
@@ -1053,3 +1055,241 @@ LogicalResult mlir::applyOpPatternsAndFold( }); return converged; } + +//===--===// +// One-Shot Dialect Conversion Infrastructure +//===--===// + +namespace { +/// A conversion rewriter for the One-Shot Dialect Conversion. This rewriter +/// immediately materializes all IR changes. It derives from +/// `ConversionPatternRewriter` so that the existing conversion patterns can +/// be used with the One-Shot Dialect Conversion. +class OneShotConversionPatternRewriter : public ConversionPatternRewriter { +public: + OneShotConversionPatternRewriter(MLIRContext *ctx) + : ConversionPatternRewriter(ctx) {} + + bool canRecoverFromRewriteFailure() const override { return false; } + + void replaceOp(Operation *op, ValueRange newValues) override; + + void replaceOp(Operation *op, Operation *newOp) override { +replaceOp(op, newOp->getResults()); + } + + void eraseOp(Operation *op) override { PatternRewriter::eraseOp(op); } + + void eraseBlock(Block *block) override { PatternRewriter::eraseBlock(block); } + + void inlineBlockBefore(Block *source, Block *dest, Block::iterator before, + ValueRange argValues = std::nullopt) override { +PatternRewriter::inlineBlockBefore(source, dest, before, argValues); + } + using PatternRewriter::inlineBlockBefore; + + void startOpModification(Operation *op) override { +PatternRewriter::startOpModification(op); + } + + void finalizeOpModification(Operation *op) override { +PatternRewriter::finalizeOpModification(op); + } + + void cancelOpModification(Operation *op) override { +PatternRewriter::cancelOpModification(op); + } + + void setCurrentTypeConverter(const TypeConverter *converter) override { +typeConverter = converter; + } + + const TypeConverter *getCurrentTypeConverter() const override { +return typeConverter; + } + + LogicalResult getAdapterOperands(StringRef valueDiagTag, + std::optional inputLoc, + ValueRange values, + SmallVector &remapped) override; + +private: + /// Build an unrealized_conversion_cast op or 
look it up in the cache. + Value buildUnrealizedConversionCast(Location loc, Type type, Value value); + + /// The current type converter. + const TypeConverter *typeConverter; + + /// A cache for unrealized_conversion_casts. To ensure that identical casts + /// are not built multiple times. + DenseMap, Value> castCache; ftynse wrote: Hmm, is it possible that the same original value is casted to multiple _different_ types within the same conversion? Type converter is currently unaware of the surrounding context, so it's unclear to me how that could happen. https://github.com/llvm/llvm-project/pull/93412 ___ llvm-branch-commits mailing list llvm-branch-commits@lists.llvm.org https://lists.llvm.org/cgi-bin/mailman/listinfo/llvm-branch-commits
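The `castCache` above is keyed on a pair; the archive dropped its template arguments. A minimal sketch of the caching idea, under stated assumptions: plain integers stand in for `Value`, strings stand in for `Type`, and every name is hypothetical. Keying on the (value, type) pair means a repeated request reuses the first cast. A new cast is built only when the same value is requested with a different target type, and whether that can happen in practice is exactly the question raised above.

```cpp
#include <map>
#include <string>
#include <utility>

// Stand-ins for mlir::Value / mlir::Type; purely illustrative.
using ValueId = int;
using TypeName = std::string;

class CastCache {
public:
  // Returns the id of a cast of `value` to `type`, building a new one only if
  // no identical (value, type) cast was requested before.
  ValueId getOrBuildCast(ValueId value, const TypeName &type) {
    auto key = std::make_pair(value, type);
    auto it = cache.find(key);
    if (it != cache.end())
      return it->second; // reuse the existing cast
    ValueId result = nextId++;
    ++numBuilt;
    cache.emplace(key, result);
    return result;
  }

  int numCastsBuilt() const { return numBuilt; }

private:
  std::map<std::pair<ValueId, TypeName>, ValueId> cache;
  ValueId nextId = 1000; // ids handed out for newly built casts
  int numBuilt = 0;
};
```

If the type converter always maps a given value to a single type, a key on the value alone would be enough. The pair key only matters when the same value can be requested with different target types.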
[llvm-branch-commits] [mlir] [draft] Dialect Conversion without Rollback (PR #93412)
@@ -1819,6 +1822,22 @@ detail::ConversionPatternRewriterImpl &ConversionPatternRewriter::getImpl() { return *impl; } +void ConversionPatternRewriter::setCurrentTypeConverter( +const TypeConverter *converter) { + impl->currentTypeConverter = converter; +} + +const TypeConverter * +ConversionPatternRewriter::getCurrentTypeConverter() const { + return impl->currentTypeConverter; +} + +LogicalResult ConversionPatternRewriter::getAdapterOperands( ftynse wrote: Nit: can we agree on Adapter/Adaptor spelling? E.g., we already have `OpAdaptor`. https://github.com/llvm/llvm-project/pull/93412
[llvm-branch-commits] [mlir] [draft] Dialect Conversion without Rollback (PR #93412)
@@ -321,15 +323,15 @@ class RandomizedWorklist : public Worklist { /// to the worklist in the beginning. class GreedyPatternRewriteDriver : public RewriterBase::Listener { protected: - explicit GreedyPatternRewriteDriver(MLIRContext *ctx, + explicit GreedyPatternRewriteDriver(PatternRewriter &rewriter, const FrozenRewritePatternSet &patterns, const GreedyRewriteConfig &config); /// Add the given operation to the worklist. void addSingleOpToWorklist(Operation *op); /// Add the given operation and its ancestors to the worklist. - void addToWorklist(Operation *op); + virtual void addToWorklist(Operation *op); ftynse wrote: Do we have an estimate of how much overhead adding a vtable to this class introduces? https://github.com/llvm/llvm-project/pull/93412
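One data point for the question above. `GreedyPatternRewriteDriver` already derives from `RewriterBase::Listener`, whose notification hooks are virtual, so driver objects should already carry a vtable pointer. On mainstream ABIs (Itanium C++, MSVC), a class with single inheritance reuses that pointer. Making `addToWorklist` virtual should therefore add a vtable slot rather than per-object storage. The remaining cost is an indirect call wherever the compiler cannot devirtualize, and only a measurement can quantify that. The sketch below uses hypothetical class names to check the size part of this claim.

```cpp
// Models a base that is already polymorphic, like a listener with virtual
// notification hooks (the hook name is hypothetical).
struct ListenerLike {
  virtual ~ListenerLike() = default;
  virtual void notifyOpErased(int) {}
};

// Driver whose worklist method is non-virtual.
struct DriverNonVirtual : ListenerLike {
  void addToWorklist(int op) { last = op; }
  int last = 0;
};

// Same driver, but addToWorklist is virtual so a subclass can override it.
struct DriverVirtual : ListenerLike {
  virtual void addToWorklist(int op) { last = op; }
  int last = 0;
};

// On mainstream ABIs both classes share the single vptr inherited from
// ListenerLike, so the extra virtual function adds no per-object storage.
static_assert(sizeof(DriverNonVirtual) == sizeof(DriverVirtual),
              "extra virtual method should not add per-object storage");
```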
[llvm-branch-commits] [mlir] [draft] Dialect Conversion without Rollback (PR #93412)
@@ -1053,3 +1055,241 @@ LogicalResult mlir::applyOpPatternsAndFold( }); return converged; } + +//===--===// +// One-Shot Dialect Conversion Infrastructure +//===--===// + +namespace { +/// A conversion rewriter for the One-Shot Dialect Conversion. This rewriter +/// immediately materializes all IR changes. It derives from +/// `ConversionPatternRewriter` so that the existing conversion patterns can +/// be used with the One-Shot Dialect Conversion. +class OneShotConversionPatternRewriter : public ConversionPatternRewriter { +public: + OneShotConversionPatternRewriter(MLIRContext *ctx) + : ConversionPatternRewriter(ctx) {} + + bool canRecoverFromRewriteFailure() const override { return false; } + + void replaceOp(Operation *op, ValueRange newValues) override; + + void replaceOp(Operation *op, Operation *newOp) override { +replaceOp(op, newOp->getResults()); + } + + void eraseOp(Operation *op) override { PatternRewriter::eraseOp(op); } + + void eraseBlock(Block *block) override { PatternRewriter::eraseBlock(block); } + + void inlineBlockBefore(Block *source, Block *dest, Block::iterator before, + ValueRange argValues = std::nullopt) override { +PatternRewriter::inlineBlockBefore(source, dest, before, argValues); + } + using PatternRewriter::inlineBlockBefore; + + void startOpModification(Operation *op) override { +PatternRewriter::startOpModification(op); + } + + void finalizeOpModification(Operation *op) override { +PatternRewriter::finalizeOpModification(op); + } + + void cancelOpModification(Operation *op) override { +PatternRewriter::cancelOpModification(op); + } ftynse wrote: Would these still be necessary after the old driver is removed? https://github.com/llvm/llvm-project/pull/93412
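The overrides quoted above call `PatternRewriter::...` explicitly. A qualified call binds statically to the named base, so each override exists only to bypass whatever the intermediate `ConversionPatternRewriter` does in its own override of the same method. The reduced sketch below illustrates that C++ mechanism. The class names are hypothetical, and the "recorded for rollback" behaviour of the middle layer is a simplified stand-in. If the middle layer stopped overriding these methods, the forwarding overrides in the leaf class would behave exactly like the inherited ones.

```cpp
#include <string>

struct PatternRewriterLike {
  virtual ~PatternRewriterLike() = default;
  // Default behaviour: apply the change immediately.
  virtual std::string eraseOp() { return "erased immediately"; }
};

struct ConversionPatternRewriterLike : PatternRewriterLike {
  // Middle layer (simplified): record the change so it can be rolled back.
  std::string eraseOp() override { return "erasure recorded for rollback"; }
};

struct OneShotRewriterLike : ConversionPatternRewriterLike {
  // Qualified call: statically bound to PatternRewriterLike::eraseOp, which
  // skips the override in ConversionPatternRewriterLike.
  std::string eraseOp() override { return PatternRewriterLike::eraseOp(); }
};
```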
[llvm-branch-commits] [mlir] [mlir][linalg] Add transform operator for Winograd Conv2D algorithm (PR #96182)
@@ -3480,6 +3480,31 @@ DiagnosedSilenceableFailure transform::MapCopyToThreadsOp::applyToOne( return DiagnosedSilenceableFailure::success(); } +//===--===// +// WinogradConv2DOp +//===--===// + +DiagnosedSilenceableFailure transform::WinogradConv2DOp::applyToOne( +transform::TransformRewriter &rewriter, linalg::LinalgOp target, +transform::ApplyToEachResultList &results, +transform::TransformState &state) { + rewriter.setInsertionPoint(target); + auto maybeTransformed = + TypeSwitch>(target) + .Case([&](linalg::Conv2DNhwcFhwcOp op) { +return winogradConv2D(rewriter, op, getM(), getR()); + }) + .Default([&](Operation *op) { +return rewriter.notifyMatchFailure(op, "not supported"); ftynse wrote: Let's rather `emitSilenceableFailure()` with this message to the user. The rewriter messages are not printed AFAIK. https://github.com/llvm/llvm-project/pull/96182
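The sketch below illustrates the suggestion above with a self-contained analogy that uses no MLIR types; every name in it is hypothetical. The code dispatches on the kind of target. In the unsupported case it returns a failure that carries the reason back to the caller, who can then show it to the user, rather than emitting a note that only a debug listener would see. In the real transform op, the caller would turn that message into the diagnostic the reviewer suggests (`emitSilenceableFailure()`).

```cpp
#include <string>
#include <variant>

// Stand-ins for two convolution op kinds.
struct Conv2DNhwcFhwc {};
struct Conv2DNchwFchw {};
using Target = std::variant<Conv2DNhwcFhwc, Conv2DNchwFchw>;

// Either success, or a failure whose message the caller can surface
// (the analogue of a silenceable failure with a user-facing message).
struct ApplyResult {
  bool succeeded;
  std::string message; // empty on success
};

ApplyResult applyWinograd(const Target &target) {
  if (std::holds_alternative<Conv2DNhwcFhwc>(target))
    return {true, ""};
  // Unsupported kind: keep the reason instead of dropping it.
  return {false, "not supported: only NHWC/FHWC 2-D convolutions"};
}
```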
[llvm-branch-commits] [mlir] [mlir][linalg] Add transform operator for Winograd Conv2D algorithm (PR #96182)
@@ -2587,4 +2587,55 @@ def MapCopyToThreadsOp : }]; } +//===--===// +// Winograd Conv2D +//===--===// + +def WinogradConv2DOp : Op { + let description = [{ +Winograd Conv2D algorithm will convert linalg Conv2D operator into batched ftynse wrote: ```suggestion Winograd Conv2D algorithm will convert linalg Conv2D operation into batched ``` Nit: these are called operations, not operators, in MLIR. https://github.com/llvm/llvm-project/pull/96182
[llvm-branch-commits] [mlir] [mlir][linalg] Add transform operator for Winograd Conv2D algorithm (PR #96182)
@@ -0,0 +1,88 @@ +// RUN: mlir-opt %s -transform-interpreter -canonicalize --split-input-file | FileCheck %s + +func.func @conv2d(%arg0: tensor<2x10x10x5xf32>, %arg1: tensor<2x3x3x5xf32>, %arg2: tensor<1xf32>) -> tensor<2x8x8x2xf32> { + %0 = tensor.empty() : tensor<2x8x8x2xf32> + %1 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2, d3) -> (0)>, affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} ins(%arg2 : tensor<1xf32>) outs(%0 : tensor<2x8x8x2xf32>) { + ^bb0(%in: f32, %out: f32): +linalg.yield %in : f32 + } -> tensor<2x8x8x2xf32> + %2 = linalg.conv_2d_nhwc_fhwc {dilations = dense<1> : tensor<2xi64>, strides = dense<1> : tensor<2xi64>} ins(%arg0, %arg1 : tensor<2x10x10x5xf32>, tensor<2x3x3x5xf32>) outs(%1 : tensor<2x8x8x2xf32>) -> tensor<2x8x8x2xf32> + return %2 : tensor<2x8x8x2xf32> +} + +module attributes {transform.with_named_sequence} { + transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) { +%0 = transform.structured.match ops{["linalg.conv_2d_nhwc_fhwc"]} in %arg1 : (!transform.any_op) -> !transform.any_op +%1 = transform.structured.winograd_conv2d %0 { m = 4, r = 3 } : (!transform.any_op) -> (!transform.any_op) +transform.yield + } +} + +// CHECK: #[[$MAP0:.+]] = affine_map<(d0, d1, d2, d3) -> (0)> +// CHECK: #[[$MAP1:.+]] = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)> +// CHECK-LABEL: func.func @conv2d +// CHECK-SAME: (%[[ARG0:.*]]: tensor<2x10x10x5xf32>, %[[ARG1:.*]]: tensor<2x3x3x5xf32>, %[[ARG2:.*]]: tensor<1xf32>) -> tensor<2x8x8x2xf32> { +// CHECK:%[[S0:.*]] = tensor.empty() : tensor<2x8x8x2xf32> +// CHECK-NEXT: %[[S1:.*]] = linalg.generic {indexing_maps = [#[[$MAP0]], #[[$MAP1]]], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} ins(%[[ARG2]] : tensor<1xf32>) outs(%[[S0]] : tensor<2x8x8x2xf32>) { +// CHECK-NEXT: ^bb0(%[[IN:.*]]: f32, %[[OUT:.*]]: f32): +// CHECK-NEXT: linalg.yield %[[IN]] : f32 +// CHECK-NEXT: } -> 
tensor<2x8x8x2xf32> +// CHECK-NEXT: %[[S2:.*]] = tensor.empty() : tensor<2x2x6x6x5x2xf32> +// CHECK-NEXT: %[[S3:.*]] = linalg.winograd_filter_transform m(4) r(3) ins(%[[ARG1]] : tensor<2x3x3x5xf32>) outs(%[[S2]] : tensor<2x2x6x6x5x2xf32>) -> tensor<2x2x6x6x5x2xf32> +// CHECK-NEXT: %[[S4:.*]] = tensor.empty() : tensor<2x2x6x6x2x5xf32> +// CHECK-NEXT: %[[S5:.*]] = linalg.winograd_input_transform m(4) r(3) ins(%[[ARG0]] : tensor<2x10x10x5xf32>) outs(%[[S4]] : tensor<2x2x6x6x2x5xf32>) -> tensor<2x2x6x6x2x5xf32> +// CHECK-NEXT: %[[COLLAPSED:.*]] = tensor.collapse_shape %[[S3]] {{\[}}[0, 1, 2, 3], [4], [5]] : tensor<2x2x6x6x5x2xf32> into tensor<144x5x2xf32> +// CHECK-NEXT: %[[COLLAPSED_0:.*]] = tensor.collapse_shape %[[S5]] {{\[}}[0, 1, 2, 3], [4], [5]] : tensor<2x2x6x6x2x5xf32> into tensor<144x2x5xf32> +// CHECK-NEXT: %[[S6:.*]] = tensor.empty() : tensor<144x2x2xf32> +// CHECK-NEXT: %[[S7:.*]] = linalg.batch_matmul ins(%[[COLLAPSED_0]], %[[COLLAPSED]] : tensor<144x2x5xf32>, tensor<144x5x2xf32>) outs(%[[S6]] : tensor<144x2x2xf32>) -> tensor<144x2x2xf32> +// CHECK-NEXT: %[[EXPANDED:.*]] = tensor.expand_shape %[[S7]] {{\[}}[0, 1, 2, 3], [4], [5]] output_shape [2, 2, 6, 6, 2, 2] : tensor<144x2x2xf32> into tensor<2x2x6x6x2x2xf32> +// CHECK-NEXT: %[[S8:.*]] = linalg.winograd_output_transform m(4) r(3) ins(%[[EXPANDED]] : tensor<2x2x6x6x2x2xf32>) outs(%[[S1]] : tensor<2x8x8x2xf32>) -> tensor<2x8x8x2xf32> +// CHECK-NEXT: return %[[S8]] : tensor<2x8x8x2xf32> +// CHECK-NEXT: } ftynse wrote: Since we are already testing the op production logic elsewhere, we don't need to re-test it here. It is sufficient to check that it worked at the high level, e.g.: ``` CHECK: winograd_filter_transform m(4) r(3) CHECK: winograd_input_transform CHECK: batch_matmul CHECK: winograd_output_transform ``` https://github.com/llvm/llvm-project/pull/96182
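The shapes in the CHECK lines above follow from simple tiling arithmetic, which a short sketch can reproduce. The helper names are hypothetical, and the sketch assumes the output extent is a whole number of m x m tiles. With m = 4 and r = 3, an 8 x 8 output splits into 2 x 2 tiles, and each transformed tile is alpha x alpha = 6 x 6. Collapsing the leading (tileH, tileW, alpha, alpha) dimensions therefore gives the batch of 144 in `tensor<144x2x5xf32>`. Two tiles of 4 rows plus r - 1 = 2 halo rows cover exactly the 10-row input. In the `conv2d_unaligned` case of the same test file, the 9-wide output is not a multiple of 4, so this simple computation does not apply. How the patch handles that case is not visible in this excerpt.

```cpp
#include <cstdint>

// Shape bookkeeping for F(m x m, r x r) Winograd tiling (illustrative only).
struct WinogradTiling {
  int64_t alpha;  // transformed tile size: m + r - 1
  int64_t tilesH; // number of m x m output tiles along H
  int64_t tilesW; // number of m x m output tiles along W
  int64_t batch;  // collapsed batch size fed to batch_matmul
};

// Fills `t` and returns true when both output extents are multiples of m;
// returns false for the "unaligned" case.
bool computeTiling(int64_t outH, int64_t outW, int64_t m, int64_t r,
                   WinogradTiling &t) {
  if (outH % m != 0 || outW % m != 0)
    return false;
  t.alpha = m + r - 1;
  t.tilesH = outH / m;
  t.tilesW = outW / m;
  t.batch = t.tilesH * t.tilesW * t.alpha * t.alpha;
  return true;
}

// Input extent covered by `tiles` consecutive tiles: each tile produces m
// outputs, and the last one needs r - 1 extra input rows.
int64_t coveredInputExtent(int64_t tiles, int64_t m, int64_t r) {
  return tiles * m + r - 1;
}
```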
[llvm-branch-commits] [mlir] [mlir][linalg] Add transform operator for Winograd Conv2D algorithm (PR #96182)
@@ -2587,4 +2587,55 @@ def MapCopyToThreadsOp : }]; } +//===--===// +// Winograd Conv2D +//===--===// + +def WinogradConv2DOp : Op { + let description = [{ +Winograd Conv2D algorithm will convert linalg Conv2D operator into batched +matrix multiply. Before the matrix multiply, it will convert filter and +input into a format suitable for batched matrix multiply. After the matrix +multiply, it will convert output to the final result tensor. + +The algorithm F(m x m, r x r) is + +Y = A^T x [(G x g x G^T) @ (B^T x d x B)] x A + +The size of output Y is m x m. The size of filter g is r x r. The size of +input d is (m + r - 1) x (m + r - 1). A^T, A, G^T, G, B^T, and B are +transformation matrices. + + Return modes: + +This operation fails if `target` is unsupported. Otherwise, the operation ftynse wrote: ```suggestion This operation produces a silenceable failure if `target` is unsupported. Otherwise, the operation ``` https://github.com/llvm/llvm-project/pull/96182
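A small numerical check of F(2 x 2, 3 x 3) in plain C++ makes the formula in the description concrete. Three assumptions apply. It reads `@` in the formula as the element-wise product. It uses the commonly cited Lavin and Gray transformation matrices rather than the tables in the patch, which differ at least in row order and signs. It compares the result against a direct 3 x 3 correlation of one 4 x 4 input tile (stride 1, no padding).

```cpp
#include <cstddef>
#include <vector>

using Mat = std::vector<std::vector<double>>;

Mat matmul(const Mat &a, const Mat &b) {
  Mat c(a.size(), std::vector<double>(b[0].size(), 0.0));
  for (std::size_t i = 0; i < a.size(); ++i)
    for (std::size_t k = 0; k < b.size(); ++k)
      for (std::size_t j = 0; j < b[0].size(); ++j)
        c[i][j] += a[i][k] * b[k][j];
  return c;
}

Mat transpose(const Mat &a) {
  Mat t(a[0].size(), std::vector<double>(a.size()));
  for (std::size_t i = 0; i < a.size(); ++i)
    for (std::size_t j = 0; j < a[0].size(); ++j)
      t[j][i] = a[i][j];
  return t;
}

// Y = A^T [ (G g G^T) (.) (B^T d B) ] A for a 4x4 tile d and a 3x3 filter g.
Mat winogradF2x2_3x3(const Mat &d, const Mat &g) {
  const Mat BT = {{1, 0, -1, 0}, {0, 1, 1, 0}, {0, -1, 1, 0}, {0, 1, 0, -1}};
  const Mat G = {{1, 0, 0}, {0.5, 0.5, 0.5}, {0.5, -0.5, 0.5}, {0, 0, 1}};
  const Mat AT = {{1, 1, 1, 0}, {0, 1, -1, -1}};

  Mat U = matmul(matmul(G, g), transpose(G));   // 4x4 transformed filter
  Mat V = matmul(matmul(BT, d), transpose(BT)); // 4x4 transformed input
  Mat M(4, std::vector<double>(4));
  for (int i = 0; i < 4; ++i)
    for (int j = 0; j < 4; ++j)
      M[i][j] = U[i][j] * V[i][j];              // element-wise product
  return matmul(matmul(AT, M), transpose(AT));  // 2x2 output tile
}

// Reference: direct 2-D correlation, stride 1, no padding.
Mat direct2x2_3x3(const Mat &d, const Mat &g) {
  Mat y(2, std::vector<double>(2, 0.0));
  for (int i = 0; i < 2; ++i)
    for (int j = 0; j < 2; ++j)
      for (int u = 0; u < 3; ++u)
        for (int v = 0; v < 3; ++v)
          y[i][j] += d[i + u][j + v] * g[u][v];
  return y;
}
```

The 16 element-wise products replace the 36 multiplications of the direct 2 x 2 output tile. The transforms themselves cost only additions and scalings by constants.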
[llvm-branch-commits] [mlir] [mlir][linalg] Add transform operator for Winograd Conv2D algorithm (PR #96182)
@@ -0,0 +1,88 @@ +// RUN: mlir-opt %s -transform-interpreter -canonicalize --split-input-file | FileCheck %s + +func.func @conv2d(%arg0: tensor<2x10x10x5xf32>, %arg1: tensor<2x3x3x5xf32>, %arg2: tensor<1xf32>) -> tensor<2x8x8x2xf32> { + %0 = tensor.empty() : tensor<2x8x8x2xf32> + %1 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2, d3) -> (0)>, affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} ins(%arg2 : tensor<1xf32>) outs(%0 : tensor<2x8x8x2xf32>) { + ^bb0(%in: f32, %out: f32): +linalg.yield %in : f32 + } -> tensor<2x8x8x2xf32> + %2 = linalg.conv_2d_nhwc_fhwc {dilations = dense<1> : tensor<2xi64>, strides = dense<1> : tensor<2xi64>} ins(%arg0, %arg1 : tensor<2x10x10x5xf32>, tensor<2x3x3x5xf32>) outs(%1 : tensor<2x8x8x2xf32>) -> tensor<2x8x8x2xf32> + return %2 : tensor<2x8x8x2xf32> +} + +module attributes {transform.with_named_sequence} { + transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) { +%0 = transform.structured.match ops{["linalg.conv_2d_nhwc_fhwc"]} in %arg1 : (!transform.any_op) -> !transform.any_op +%1 = transform.structured.winograd_conv2d %0 { m = 4, r = 3 } : (!transform.any_op) -> (!transform.any_op) +transform.yield + } +} + +// CHECK: #[[$MAP0:.+]] = affine_map<(d0, d1, d2, d3) -> (0)> +// CHECK: #[[$MAP1:.+]] = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)> +// CHECK-LABEL: func.func @conv2d +// CHECK-SAME: (%[[ARG0:.*]]: tensor<2x10x10x5xf32>, %[[ARG1:.*]]: tensor<2x3x3x5xf32>, %[[ARG2:.*]]: tensor<1xf32>) -> tensor<2x8x8x2xf32> { +// CHECK:%[[S0:.*]] = tensor.empty() : tensor<2x8x8x2xf32> +// CHECK-NEXT: %[[S1:.*]] = linalg.generic {indexing_maps = [#[[$MAP0]], #[[$MAP1]]], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} ins(%[[ARG2]] : tensor<1xf32>) outs(%[[S0]] : tensor<2x8x8x2xf32>) { +// CHECK-NEXT: ^bb0(%[[IN:.*]]: f32, %[[OUT:.*]]: f32): +// CHECK-NEXT: linalg.yield %[[IN]] : f32 +// CHECK-NEXT: } -> 
tensor<2x8x8x2xf32> +// CHECK-NEXT: %[[S2:.*]] = tensor.empty() : tensor<2x2x6x6x5x2xf32> +// CHECK-NEXT: %[[S3:.*]] = linalg.winograd_filter_transform m(4) r(3) ins(%[[ARG1]] : tensor<2x3x3x5xf32>) outs(%[[S2]] : tensor<2x2x6x6x5x2xf32>) -> tensor<2x2x6x6x5x2xf32> +// CHECK-NEXT: %[[S4:.*]] = tensor.empty() : tensor<2x2x6x6x2x5xf32> +// CHECK-NEXT: %[[S5:.*]] = linalg.winograd_input_transform m(4) r(3) ins(%[[ARG0]] : tensor<2x10x10x5xf32>) outs(%[[S4]] : tensor<2x2x6x6x2x5xf32>) -> tensor<2x2x6x6x2x5xf32> +// CHECK-NEXT: %[[COLLAPSED:.*]] = tensor.collapse_shape %[[S3]] {{\[}}[0, 1, 2, 3], [4], [5]] : tensor<2x2x6x6x5x2xf32> into tensor<144x5x2xf32> +// CHECK-NEXT: %[[COLLAPSED_0:.*]] = tensor.collapse_shape %[[S5]] {{\[}}[0, 1, 2, 3], [4], [5]] : tensor<2x2x6x6x2x5xf32> into tensor<144x2x5xf32> +// CHECK-NEXT: %[[S6:.*]] = tensor.empty() : tensor<144x2x2xf32> +// CHECK-NEXT: %[[S7:.*]] = linalg.batch_matmul ins(%[[COLLAPSED_0]], %[[COLLAPSED]] : tensor<144x2x5xf32>, tensor<144x5x2xf32>) outs(%[[S6]] : tensor<144x2x2xf32>) -> tensor<144x2x2xf32> +// CHECK-NEXT: %[[EXPANDED:.*]] = tensor.expand_shape %[[S7]] {{\[}}[0, 1, 2, 3], [4], [5]] output_shape [2, 2, 6, 6, 2, 2] : tensor<144x2x2xf32> into tensor<2x2x6x6x2x2xf32> +// CHECK-NEXT: %[[S8:.*]] = linalg.winograd_output_transform m(4) r(3) ins(%[[EXPANDED]] : tensor<2x2x6x6x2x2xf32>) outs(%[[S1]] : tensor<2x8x8x2xf32>) -> tensor<2x8x8x2xf32> +// CHECK-NEXT: return %[[S8]] : tensor<2x8x8x2xf32> +// CHECK-NEXT: } + +// - + +func.func @conv2d_unaligned(%arg0: tensor<2x11x11x5xf32>, %arg1: tensor<2x3x3x5xf32>, %arg2: tensor<1xf32>) -> tensor<2x9x9x2xf32> { + %0 = tensor.empty() : tensor<2x9x9x2xf32> + %1 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2, d3) -> (0)>, affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} ins(%arg2 : tensor<1xf32>) outs(%0 : tensor<2x9x9x2xf32>) { + ^bb0(%in: f32, %out: f32): +linalg.yield %in : f32 + } -> 
tensor<2x9x9x2xf32> + %2 = linalg.conv_2d_nhwc_fhwc {dilations = dense<1> : tensor<2xi64>, strides = dense<1> : tensor<2xi64>} ins(%arg0, %arg1 : tensor<2x11x11x5xf32>, tensor<2x3x3x5xf32>) outs(%1 : tensor<2x9x9x2xf32>) -> tensor<2x9x9x2xf32> + return %2 : tensor<2x9x9x2xf32> +} + +module attributes {transform.with_named_sequence} { + transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) { +%0 = transform.structured.match ops{["linalg.conv_2d_nhwc_fhwc"]} in %arg1 : (!transform.any_op) -> !transform.any_op +%1 = transform.structured.winograd_conv2d %0 { m = 4, r = 3 } : (!transform.any_op) -> (!transform.any_op) +transform.yield + } +} + +// CHECK: #[[$MAP0:.+]] = affine_map<(d0, d1, d2, d3) -> (0)> +// CHECK: #[[$MAP1:.+]] = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)> +// CHECK-LABEL: func.func @conv2d_unaligned
[llvm-branch-commits] [mlir] [mlir][linalg] Add transform operator for Winograd Conv2D algorithm (PR #96182)
@@ -3480,6 +3480,31 @@ DiagnosedSilenceableFailure transform::MapCopyToThreadsOp::applyToOne( return DiagnosedSilenceableFailure::success(); } +//===--===// +// WinogradConv2DOp +//===--===// + +DiagnosedSilenceableFailure transform::WinogradConv2DOp::applyToOne( +transform::TransformRewriter &rewriter, linalg::LinalgOp target, +transform::ApplyToEachResultList &results, +transform::TransformState &state) { + rewriter.setInsertionPoint(target); + auto maybeTransformed = + TypeSwitch>(target) + .Case([&](linalg::Conv2DNhwcFhwcOp op) { +return winogradConv2D(rewriter, op, getM(), getR()); + }) + .Default([&](Operation *op) { +return rewriter.notifyMatchFailure(op, "not supported"); + }); + + if (failed(maybeTransformed)) +return emitDefaultSilenceableFailure(target); ftynse wrote: It would be nice if `winogradConv2D` returned some error code or took a callback to print error messages to, instead of giving a default message here. Non-blocking. https://github.com/llvm/llvm-project/pull/96182
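A hypothetical sketch of the callback variant suggested above. The function name and signature are not the patch's. Its checks mirror what the quoted hunks show: filter height and width must equal r or 1, and only the (m, r) keys F_2_3, F_4_3, and F_2_5 have matrices. The transformation reports the specific reason for a failure through a caller-supplied callback, so the caller no longer has to fall back to a default message.

```cpp
#include <cstdint>
#include <functional>
#include <string>

// Caller-provided sink for error messages (e.g. one that builds a diagnostic).
using DiagnosticFn = std::function<void(const std::string &)>;

// Returns true if the rewrite can proceed; otherwise reports why and returns
// false.
bool winogradConv2DSketch(int64_t filterH, int64_t filterW, int64_t m,
                          int64_t r, const DiagnosticFn &emitError) {
  if ((filterH != r && filterH != 1) || (filterW != r && filterW != 1)) {
    emitError("filter height/width must be r or 1");
    return false;
  }
  bool supported = (m == 2 && r == 3) || (m == 4 && r == 3) ||
                   (m == 2 && r == 5);
  if (!supported) {
    emitError("unsupported (m, r) combination");
    return false;
  }
  return true; // the actual rewrite would happen here
}
```

Compared with returning an error code, a callback lets the caller decide how to render the message, for instance as a silenceable failure attached to the transform op, without the transformation depending on the transform dialect.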
[llvm-branch-commits] [mlir] [mlir][linalg] Decompose winograd operators (PR #96183)
@@ -36,6 +189,92 @@ constexpr TransformMapKeyTy F_2_3{2, 3}; constexpr TransformMapKeyTy F_4_3{4, 3}; constexpr TransformMapKeyTy F_2_5{2, 5}; +struct TransformMatrix { ftynse wrote: Please document top-level entities. https://github.com/llvm/llvm-project/pull/96183
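The body of `TransformMatrix` is not shown in this excerpt. The sketch below is a hypothetical, documented version whose fields are inferred from the call sites quoted in this thread, such as `TransformMatrix(G_2x2_3x3, 4, 3)`, `TransformMatrix(AT_4x4_3x3, 4, 6, 32)`, `.rows`, `.cols` and `.scalarFactor`. The member name `table` and the `at()` accessor are assumptions.

```cpp
#include <cstdint>

/// A constant Winograd transformation matrix (one of G, G^T, B, B^T, A, A^T)
/// stored in row-major order, together with its shape.
struct TransformMatrix {
  TransformMatrix(const float *table, int64_t rows, int64_t cols,
                  int64_t scalarFactor = 1)
      : table(table), rows(rows), cols(cols), scalarFactor(scalarFactor) {}

  /// Row-major coefficients; must point to at least rows * cols elements.
  const float *table;
  /// Number of rows.
  int64_t rows;
  /// Number of columns.
  int64_t cols;
  /// Factor the matmul result is multiplied by afterwards (inferred from how
  /// outputTransform uses it).
  int64_t scalarFactor;

  /// Coefficient at (row, col).
  float at(int64_t row, int64_t col) const { return table[row * cols + col]; }
};
```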
[llvm-branch-commits] [mlir] [mlir][linalg] Decompose winograd operators (PR #96183)
@@ -23,6 +26,156 @@ namespace linalg { namespace { +// clang-format off +// Winograd Conv2D uses a minimal 2D filtering algorithm to calculate its +// result. The formula of minimal 2D filtering algorithm F(m x m, r x r), +// m is the output dimension and r is the filter dimension, is +// +// Y = A^T x [ (G x g x G^T) x (B^T x d x B) ] x A +// +// g is filter and d is input data. We need to prepare 6 constant +// transformation matrices, G, G^T, B^T, B, A^T, and A for this formula. +// +// The following tables define these constant transformation matrices for +// F(2 x 2, 3 x 3), F(4 x 4, 3 x 3), and F(2 x 2, 5 x 5) +constexpr float G_2x2_3x3[] = { + -1, 0, 0, + 1./2, -1./2, 1./2, + 1./2, 1./2, 1./2, +0, 0,1 +}; + +constexpr float GT_2x2_3x3[] = { + -1, 1./2, 1./2, 0, +0, -1./2, 1./2, 0, +0, 1./2, 1./2, 1 +}; ftynse wrote: Have you considered introducing a (potentially `constexpr`) transpose function or some sort of transposed access iterator instead of hardcoding transposed matrices? https://github.com/llvm/llvm-project/pull/96183
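A minimal C++17 sketch of the `constexpr` transpose idea. It assumes the tables are switched from C arrays to `std::array`, and the helper names are hypothetical. G^T is derived at compile time from the patch's `G_2x2_3x3` values, and a `static_assert` checks the result against the hand-written `GT_2x2_3x3` table quoted above.

```cpp
#include <array>
#include <cstddef>

// Transpose a row-major R x C matrix at compile time.
template <std::size_t R, std::size_t C>
constexpr std::array<float, R * C>
transpose(const std::array<float, R * C> &m) {
  std::array<float, R * C> t{};
  for (std::size_t i = 0; i < R; ++i)
    for (std::size_t j = 0; j < C; ++j)
      t[j * R + i] = m[i * C + j];
  return t;
}

// Element-wise comparison usable in constant expressions (C++17).
template <std::size_t N>
constexpr bool sameEntries(const std::array<float, N> &a,
                           const std::array<float, N> &b) {
  for (std::size_t i = 0; i < N; ++i)
    if (a[i] != b[i])
      return false;
  return true;
}

// G for F(2 x 2, 3 x 3), same values as the patch (4 x 3, row-major).
constexpr std::array<float, 12> G_2x2_3x3 = {
    -1,    0,     0,
    0.5f, -0.5f,  0.5f,
    0.5f,  0.5f,  0.5f,
    0,     0,     1};

// G^T derived instead of hand-written (3 x 4, row-major).
constexpr std::array<float, 12> GT_2x2_3x3 = transpose<4, 3>(G_2x2_3x3);

static_assert(sameEntries(GT_2x2_3x3,
                          std::array<float, 12>{-1, 0.5f, 0.5f, 0,
                                                0, -0.5f, 0.5f, 0,
                                                0, 0.5f, 0.5f, 1}),
              "derived G^T must match the hand-written table");
```

A transposed-access view would avoid storing the second table at all. The `constexpr` route keeps the existing "pointer plus shape" representation and removes only the hand-maintained duplicate.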
[llvm-branch-commits] [mlir] [mlir][linalg] Decompose winograd operators (PR #96183)
@@ -48,6 +287,261 @@ Value collapse2DData(RewriterBase &rewriter, Location loc, Value data) { reassociation); } +// This function transforms the filter. The data layout of the filter is FHWC. +// The transformation matrix is 2-dimension. We need to extract H x W from +// FHWC first. We need to generate 2 levels of loops to iterate on F and C. +// After the transformation, we get +// +// scf.for %f = lo_f to hi_f step 1 +// scf.for %c = lo_c to hi_c step 1 +// %extracted = extract filter from filter +// %ret = linalg.matmul G, %extracted +// %ret = linalg.matmul %ret, GT +// %inserted = insert %ret into filter +// +Value filterTransform(RewriterBase &rewriter, Location loc, Value filter, + Value retValue, int64_t m, int64_t r, + bool leftTransform = true, bool rightTransform = true) { + // Map from (m, r) to G transform matrix. + static const llvm::SmallDenseMap + GMatrices = { + {F_2_3, TransformMatrix(G_2x2_3x3, 4, 3)}, + {F_4_3, TransformMatrix(G_4x4_3x3, 6, 3)}, + {F_2_5, TransformMatrix(G_2x2_5x5, 6, 5)}, + }; + + // Map from (m, r) to GT transform matrix. 
+ static const llvm::SmallDenseMap + GTMatrices = { + {F_2_3, TransformMatrix(GT_2x2_3x3, 3, 4)}, + {F_4_3, TransformMatrix(GT_4x4_3x3, 3, 6)}, + {F_2_5, TransformMatrix(GT_2x2_5x5, 5, 6)}, + }; + + auto filterType = cast(filter.getType()); + Type elementType = filterType.getElementType(); + auto filterShape = filterType.getShape(); // F, H, W, C + int64_t filterF = filterShape[0]; + int64_t filterH = filterShape[1]; + int64_t filterW = filterShape[2]; + int64_t filterC = filterShape[3]; + + if (filterH != r && filterH != 1) +return Value(); + if (filterW != r && filterW != 1) +return Value(); + + // Return shape is + auto zeroIdx = rewriter.create(loc, 0); + auto fUpperBound = rewriter.create(loc, filterF); + auto cUpperBound = rewriter.create(loc, filterC); + auto oneStep = rewriter.create(loc, 1); + auto outerForOp = + rewriter.create(loc, zeroIdx, fUpperBound, oneStep, retValue); + Block *outerForBody = outerForOp.getBody(); + rewriter.setInsertionPointToStart(outerForBody); + Value FIter = outerForBody->getArgument(0); + + auto innerForOp = rewriter.create( + loc, zeroIdx, cUpperBound, oneStep, outerForOp.getRegionIterArgs()[0]); ftynse wrote: Ditto. There must be a better-named function for this. https://github.com/llvm/llvm-project/pull/96183
[llvm-branch-commits] [mlir] [mlir][linalg] Decompose winograd operators (PR #96183)
@@ -48,6 +287,261 @@ Value collapse2DData(RewriterBase &rewriter, Location loc, Value data) { reassociation); } +// This function transforms the filter. The data layout of the filter is FHWC. +// The transformation matrix is 2-dimension. We need to extract H x W from +// FHWC first. We need to generate 2 levels of loops to iterate on F and C. +// After the transformation, we get +// +// scf.for %f = lo_f to hi_f step 1 +// scf.for %c = lo_c to hi_c step 1 +// %extracted = extract filter from filter +// %ret = linalg.matmul G, %extracted +// %ret = linalg.matmul %ret, GT +// %inserted = insert %ret into filter +// +Value filterTransform(RewriterBase &rewriter, Location loc, Value filter, + Value retValue, int64_t m, int64_t r, + bool leftTransform = true, bool rightTransform = true) { + // Map from (m, r) to G transform matrix. + static const llvm::SmallDenseMap + GMatrices = { + {F_2_3, TransformMatrix(G_2x2_3x3, 4, 3)}, + {F_4_3, TransformMatrix(G_4x4_3x3, 6, 3)}, + {F_2_5, TransformMatrix(G_2x2_5x5, 6, 5)}, + }; + + // Map from (m, r) to GT transform matrix. 
+ static const llvm::SmallDenseMap + GTMatrices = { + {F_2_3, TransformMatrix(GT_2x2_3x3, 3, 4)}, + {F_4_3, TransformMatrix(GT_4x4_3x3, 3, 6)}, + {F_2_5, TransformMatrix(GT_2x2_5x5, 5, 6)}, + }; + + auto filterType = cast(filter.getType()); + Type elementType = filterType.getElementType(); + auto filterShape = filterType.getShape(); // F, H, W, C + int64_t filterF = filterShape[0]; + int64_t filterH = filterShape[1]; + int64_t filterW = filterShape[2]; + int64_t filterC = filterShape[3]; + + if (filterH != r && filterH != 1) +return Value(); + if (filterW != r && filterW != 1) +return Value(); + + // Return shape is + auto zeroIdx = rewriter.create(loc, 0); + auto fUpperBound = rewriter.create(loc, filterF); + auto cUpperBound = rewriter.create(loc, filterC); + auto oneStep = rewriter.create(loc, 1); + auto outerForOp = + rewriter.create(loc, zeroIdx, fUpperBound, oneStep, retValue); + Block *outerForBody = outerForOp.getBody(); + rewriter.setInsertionPointToStart(outerForBody); + Value FIter = outerForBody->getArgument(0); + + auto innerForOp = rewriter.create( + loc, zeroIdx, cUpperBound, oneStep, outerForOp.getRegionIterArgs()[0]); + Block *innerForBody = innerForOp.getBody(); + rewriter.setInsertionPointToStart(innerForBody); + Value CIter = innerForBody->getArgument(0); + + // Extract (H, W) from (F, H, W, C) + auto extractFilter = extract2DData( + rewriter, loc, filter, FIter, CIter, /*outLoopIdx=*/0, + /*inLoopIdx=*/3, /*heightIdx=*/1, /*widthIdx=*/2, /*srcSize=*/4); + + TransformMapKeyTy key = {m, r}; + int64_t retRows = 1; + Value matmulRetValue = extractFilter; + if (leftTransform) { +// Get constant transform matrix G +auto it = GMatrices.find(key); +if (it == GMatrices.end()) + return Value(); +const TransformMatrix &GMatrix = it->second; + +retRows = GMatrix.rows; +auto matmulType = RankedTensorType::get({retRows, filterW}, elementType); +auto init = rewriter.create(loc, matmulType.getShape(), + elementType); + +Value G = 
create2DTransformMatrix(rewriter, loc, GMatrix, elementType); ftynse wrote: I wonder if we would rather provide these matrices as global memrefs instead of creating them locally every time. Have you considered that? https://github.com/llvm/llvm-project/pull/96183
[llvm-branch-commits] [mlir] [mlir][linalg] Decompose winograd operators (PR #96183)
@@ -323,5 +1089,12 @@ void populateWinogradConv2DPatterns(RewritePatternSet &patterns, int64_t m, patterns.insert(context, m, r); } +void populateDecomposeWinogradOpsPatterns(RewritePatternSet &patterns) { + MLIRContext *context = patterns.getContext(); + patterns.insert(context); + patterns.insert(context); + patterns.insert(context); ftynse wrote: ```suggestion patterns.insert(context); ``` https://github.com/llvm/llvm-project/pull/96183
[llvm-branch-commits] [mlir] [mlir][linalg] Decompose winograd operators (PR #96183)
@@ -100,6 +594,161 @@ Value matrixMultiply(RewriterBase &rewriter, Location loc, return expandOutput; } +// This function transforms the output. The data layout of the output is HWNF. +// The transformation matrix is 2-dimension. We need to extract H x W from +// HWNF first. We need to generate 2 levels of loops to iterate on N and F. +// After the transformation, we get +// +// scf.for %n = lo_n to hi_n step 1 +// scf.for %f = lo_f to hi_f step 1 +// %extracted = extract input from result +// %ret = linalg.matmul AT, %extracted +// %ret = linalg.matmul %ret, A +// %inserted = insert %ret into ret +// +Value outputTransform(RewriterBase &rewriter, Location loc, Value value, + Value output, int64_t m, int64_t r, + bool leftTransform = true, bool rightTransform = true) { + // Map from (m, r) to AT transform matrix. + static const llvm::SmallDenseMap + ATMatrices = { + {F_2_3, TransformMatrix(AT_2x2_3x3, 2, 4)}, + {F_4_3, TransformMatrix(AT_4x4_3x3, 4, 6, 32)}, + {F_2_5, TransformMatrix(AT_2x2_5x5, 2, 6, 16)}, + }; + + // Map from (m, r) to A transform matrix. + static const llvm::SmallDenseMap + AMatrices = { + {F_2_3, TransformMatrix(A_2x2_3x3, 4, 2)}, + {F_4_3, TransformMatrix(A_4x4_3x3, 6, 4, 32)}, + {F_2_5, TransformMatrix(A_2x2_5x5, 6, 2, 16)}, + }; + + auto valueType = cast(value.getType()); + Type elementType = valueType.getElementType(); + auto valueShape = valueType.getShape(); // TileH, TileW, H, W, N, F + int64_t valueH = valueShape[2]; + int64_t valueW = valueShape[3]; + int64_t valueN = valueShape[4]; + int64_t valueF = valueShape[5]; + int64_t alphaH = leftTransform ? m + r - 1 : 1; + int64_t alphaW = rightTransform ? 
m + r - 1 : 1; + + if (valueH != alphaH && valueH != 1) +return Value(); + if (valueW != alphaW && valueW != 1) +return Value(); + + auto zeroIdx = rewriter.create(loc, 0); + auto nUpperBound = rewriter.create(loc, valueN); + auto fUpperBound = rewriter.create(loc, valueF); + auto oneStep = rewriter.create(loc, 1); + + auto outerForOp = + rewriter.create(loc, zeroIdx, nUpperBound, oneStep, output); + Block *outerForBody = outerForOp.getBody(); + rewriter.setInsertionPointToStart(outerForBody); + Value NIter = outerForBody->getArgument(0); + + auto innerForOp = rewriter.create( + loc, zeroIdx, fUpperBound, oneStep, outerForOp.getRegionIterArgs()[0]); + Block *innerForBody = innerForOp.getBody(); + rewriter.setInsertionPointToStart(innerForBody); + Value FIter = innerForBody->getArgument(0); + + // Extract (H, W) from (1, 1, H, W, N, F) + auto extractValue = extract2DData( + rewriter, loc, value, NIter, FIter, /*outLoopIdx=*/4, + /*inLoopIdx=*/5, /*heightIdx=*/2, /*widthIdx=*/3, /*srcSize=*/6); + + TransformMapKeyTy key = {m, r}; + int64_t retRows = 1; + int64_t retCols = 1; + int64_t leftScalarFactor = 1; + int64_t rightScalarFactor = 1; + Value matmulRetValue = extractValue; + if (leftTransform) { +// Get constant transform matrix AT +auto it = ATMatrices.find(key); +if (it == ATMatrices.end()) + return Value(); +const TransformMatrix &ATMatrix = it->second; + +leftScalarFactor = ATMatrix.scalarFactor; +retRows = ATMatrix.rows; +auto matmulType = RankedTensorType::get({retRows, valueW}, elementType); +auto init = rewriter.create(loc, matmulType.getShape(), + elementType); + +Value AT = create2DTransformMatrix(rewriter, loc, ATMatrix, elementType); +// Multiply AT x m +auto matmulOp = rewriter.create( +loc, matmulType, ValueRange{AT, matmulRetValue}, ValueRange{init}); +matmulRetValue = matmulOp.getResult(0); + } + + if (rightTransform) { +// Get constant transform matrix T +auto it = AMatrices.find(key); +if (it == AMatrices.end()) + return Value(); +const 
TransformMatrix &AMatrix = it->second; + +rightScalarFactor = AMatrix.scalarFactor; +auto matmulType = +RankedTensorType::get({retRows, AMatrix.cols}, elementType); +retCols = AMatrix.cols; +auto init = rewriter.create(loc, matmulType.getShape(), + elementType); + +Value A = create2DTransformMatrix(rewriter, loc, AMatrix, elementType); +// Multiply y = (AT x m) x A +auto matmulOp = rewriter.create( +loc, matmulType, ValueRange{matmulRetValue, A}, ValueRange{init}); +matmulRetValue = matmulOp.getResult(0); + } + + // Multiply scalar factor. + Value scalarFactor = rewriter.create( + loc, FloatAttr::get(elementType, leftScalarFactor * rightScalarFactor)); + auto matmulType = RankedTensorType::get({retRows, retCols}, elementType); + auto init = + rewriter.create(loc, matmulType.getShape(), elementType); + + auto identityAffineMap = rewriter.getMultiDimIdentityMap(2); + Smal
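For readers who want the math rather than the IR: for each (n, f) pair, the loop nest in `outputTransform` computes Y = AT · M · A on one alpha x alpha tile and then scales the result by the product of the left and right scalar factors. Below is a minimal plain-C++ sketch of that per-tile step. It assumes the textbook F(2, 3) output matrix; the values behind `AT_2x2_3x3` are not shown in the diff and may use a different sign convention. `Matrix`, `matmul`, `transpose` and `outputTransformTile` are illustrative names, not part of the patch.

```cpp
#include <cassert>
#include <cstddef>
#include <vector>

using Matrix = std::vector<std::vector<double>>;

// Plain triple-loop product of an (n x k) and a (k x p) matrix.
Matrix matmul(const Matrix &a, const Matrix &b) {
  assert(!a.empty() && !b.empty() && a[0].size() == b.size());
  Matrix c(a.size(), std::vector<double>(b[0].size(), 0.0));
  for (std::size_t i = 0; i < a.size(); ++i)
    for (std::size_t k = 0; k < b.size(); ++k)
      for (std::size_t j = 0; j < b[0].size(); ++j)
        c[i][j] += a[i][k] * b[k][j];
  return c;
}

Matrix transpose(const Matrix &a) {
  Matrix t(a[0].size(), std::vector<double>(a.size()));
  for (std::size_t i = 0; i < a.size(); ++i)
    for (std::size_t j = 0; j < a[0].size(); ++j)
      t[j][i] = a[i][j];
  return t;
}

// Assumption: textbook A^T for F(2, 3), 2 x 4 (A is its 4 x 2 transpose).
const Matrix kATF23 = {{1, 1, 1, 0}, {0, 1, -1, -1}};

// Output transform of one alpha x alpha tile M: Y = s * (A^T * M * A), where
// s is the product of the left and right scalar factors.
Matrix outputTransformTile(const Matrix &m, const Matrix &at,
                           double scalarFactor) {
  Matrix y = matmul(matmul(at, m), transpose(at));
  for (auto &row : y)
    for (double &v : row)
      v *= scalarFactor;
  return y;
}
```

For an all-ones 4 x 4 tile and a scalar factor of 1, this yields the 2 x 2 tile [[9, -3], [-3, 1]]. In the generated IR, the same computation runs once per (n, f) iteration of the two `scf.for` loops.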
[llvm-branch-commits] [mlir] [mlir][linalg] Decompose winograd operators (PR #96183)
@@ -100,6 +594,161 @@ Value matrixMultiply(RewriterBase &rewriter, Location loc, return expandOutput; } +// This function transforms the output. The data layout of the output is HWNF. +// The transformation matrix is 2-dimension. We need to extract H x W from +// HWNF first. We need to generate 2 levels of loops to iterate on N and F. +// After the transformation, we get +// +// scf.for %n = lo_n to hi_n step 1 +// scf.for %f = lo_f to hi_f step 1 +// %extracted = extract input from result +// %ret = linalg.matmul AT, %extracted +// %ret = linalg.matmul %ret, A +// %inserted = insert %ret into ret +// +Value outputTransform(RewriterBase &rewriter, Location loc, Value value, + Value output, int64_t m, int64_t r, + bool leftTransform = true, bool rightTransform = true) { + // Map from (m, r) to AT transform matrix. + static const llvm::SmallDenseMap + ATMatrices = { + {F_2_3, TransformMatrix(AT_2x2_3x3, 2, 4)}, + {F_4_3, TransformMatrix(AT_4x4_3x3, 4, 6, 32)}, + {F_2_5, TransformMatrix(AT_2x2_5x5, 2, 6, 16)}, + }; + + // Map from (m, r) to A transform matrix. + static const llvm::SmallDenseMap + AMatrices = { + {F_2_3, TransformMatrix(A_2x2_3x3, 4, 2)}, + {F_4_3, TransformMatrix(A_4x4_3x3, 6, 4, 32)}, + {F_2_5, TransformMatrix(A_2x2_5x5, 6, 2, 16)}, + }; + + auto valueType = cast(value.getType()); + Type elementType = valueType.getElementType(); + auto valueShape = valueType.getShape(); // TileH, TileW, H, W, N, F + int64_t valueH = valueShape[2]; + int64_t valueW = valueShape[3]; + int64_t valueN = valueShape[4]; + int64_t valueF = valueShape[5]; + int64_t alphaH = leftTransform ? m + r - 1 : 1; + int64_t alphaW = rightTransform ? 
m + r - 1 : 1; + + if (valueH != alphaH && valueH != 1) +return Value(); + if (valueW != alphaW && valueW != 1) +return Value(); + + auto zeroIdx = rewriter.create(loc, 0); + auto nUpperBound = rewriter.create(loc, valueN); + auto fUpperBound = rewriter.create(loc, valueF); + auto oneStep = rewriter.create(loc, 1); + + auto outerForOp = + rewriter.create(loc, zeroIdx, nUpperBound, oneStep, output); + Block *outerForBody = outerForOp.getBody(); + rewriter.setInsertionPointToStart(outerForBody); + Value NIter = outerForBody->getArgument(0); + + auto innerForOp = rewriter.create( + loc, zeroIdx, fUpperBound, oneStep, outerForOp.getRegionIterArgs()[0]); + Block *innerForBody = innerForOp.getBody(); + rewriter.setInsertionPointToStart(innerForBody); + Value FIter = innerForBody->getArgument(0); ftynse wrote: FYI, there's a `mlir::scf::buildLoopNest` somewhere that may spare you the boilerplate. https://github.com/llvm/llvm-project/pull/96183 ___ llvm-branch-commits mailing list llvm-branch-commits@lists.llvm.org https://lists.llvm.org/cgi-bin/mailman/listinfo/llvm-branch-commits
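The hand-written outer/inner `scf.for` pair above is one instance of a generic "loop nest over a rectangular domain", which is the reviewer's point. The sketch below is a hypothetical plain-C++ analog of such a builder. It is not MLIR's `scf::buildLoopNest`, whose exact signature should be checked in the SCF dialect headers. Unlike the `scf.for` loops in the diff, it does not thread loop-carried values such as the `output` tensor.

```cpp
#include <cassert>
#include <cstddef>
#include <cstdint>
#include <functional>
#include <vector>

// Hypothetical analog of a generic loop-nest builder: runs `body` once for
// every point of the rectangular domain [lbs, ubs) with the given positive
// steps, passing the induction variables (outermost first).
void forEachInLoopNest(
    const std::vector<int64_t> &lbs, const std::vector<int64_t> &ubs,
    const std::vector<int64_t> &steps,
    const std::function<void(const std::vector<int64_t> &)> &body) {
  assert(lbs.size() == ubs.size() && ubs.size() == steps.size());
  for (int64_t step : steps)
    assert(step > 0 && "steps must be positive");
  std::vector<int64_t> ivs(lbs.size());
  // Fixes one loop level per call, then recurses into the next level.
  std::function<void(std::size_t)> visit = [&](std::size_t level) {
    if (level == lbs.size()) {
      body(ivs);
      return;
    }
    for (int64_t iv = lbs[level]; iv < ubs[level]; iv += steps[level]) {
      ivs[level] = iv;
      visit(level + 1);
    }
  };
  visit(0);
}
```

With `lbs = {0, 0}`, `ubs = {N, F}` and unit steps, `body` receives (n, f) in the same order as the quoted `NIter`/`FIter` loops. Adding a third dimension then requires no extra nesting boilerplate.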
[llvm-branch-commits] [mlir] [mlir][linalg] Decompose winograd operators (PR #96183)
@@ -48,6 +287,261 @@ Value collapse2DData(RewriterBase &rewriter, Location loc, Value data) { reassociation); } +// This function transforms the filter. The data layout of the filter is FHWC. +// The transformation matrix is 2-dimension. We need to extract H x W from +// FHWC first. We need to generate 2 levels of loops to iterate on F and C. +// After the transformation, we get +// +// scf.for %f = lo_f to hi_f step 1 +// scf.for %c = lo_c to hi_c step 1 +// %extracted = extract filter from filter +// %ret = linalg.matmul G, %extracted +// %ret = linalg.matmul %ret, GT +// %inserted = insert %ret into filter +// ftynse wrote: ```suggestion /// This function transforms the filter. The data layout of the filter is FHWC. /// The transformation matrix is 2-dimension. We need to extract H x W from /// FHWC first. We need to generate 2 levels of loops to iterate on F and C. /// After the transformation, we get /// /// scf.for %f = lo_f to hi_f step 1 /// scf.for %c = lo_c to hi_c step 1 /// %extracted = extract filter from filter /// %ret = linalg.matmul G, %extracted /// %ret = linalg.matmul %ret, GT /// %inserted = insert %ret into filter /// ``` https://github.com/llvm/llvm-project/pull/96183
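On the filter side, each (f, c) slice is transformed as U = G · g · GT, which is what the two `linalg.matmul` ops in the comment compute. Below is a minimal plain-C++ sketch. It assumes the textbook F(2, 3) G matrix (4 x 3); the patch's `G_2x2_3x3` table is not shown here and may differ in sign. The helper names are illustrative.

```cpp
#include <cassert>
#include <cstddef>
#include <vector>

using Matrix = std::vector<std::vector<double>>;

// Plain triple-loop product of an (n x k) and a (k x p) matrix.
Matrix matmul(const Matrix &a, const Matrix &b) {
  assert(!a.empty() && !b.empty() && a[0].size() == b.size());
  Matrix c(a.size(), std::vector<double>(b[0].size(), 0.0));
  for (std::size_t i = 0; i < a.size(); ++i)
    for (std::size_t k = 0; k < b.size(); ++k)
      for (std::size_t j = 0; j < b[0].size(); ++j)
        c[i][j] += a[i][k] * b[k][j];
  return c;
}

Matrix transpose(const Matrix &a) {
  Matrix t(a[0].size(), std::vector<double>(a.size()));
  for (std::size_t i = 0; i < a.size(); ++i)
    for (std::size_t j = 0; j < a[0].size(); ++j)
      t[j][i] = a[i][j];
  return t;
}

// Assumption: textbook G for F(2, 3), 4 x 3 (GT is its 3 x 4 transpose).
const Matrix kGF23 = {{1.0, 0.0, 0.0},
                      {0.5, 0.5, 0.5},
                      {0.5, -0.5, 0.5},
                      {0.0, 0.0, 1.0}};

// Filter transform of one r x r slice g (fixed f and c): U = G * g * G^T,
// an alpha x alpha tile with alpha = m + r - 1 = 4 for F(2, 3).
Matrix filterTransformTile(const Matrix &g, const Matrix &gMat) {
  return matmul(matmul(gMat, g), transpose(gMat));
}
```

For an all-ones 3 x 3 filter, U[i][j] equals v[i] * v[j], where v = (1, 1.5, 0.5, 1) holds the row sums of G.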
[llvm-branch-commits] [mlir] [mlir][linalg] Decompose winograd operators (PR #96183)
@@ -289,6 +938,123 @@ FailureOr winogradConv2DHelper(RewriterBase &rewriter, return transformedOutput.getDefiningOp(); } +FailureOr +decomposeWinogradFilterTransformHelper(RewriterBase &rewriter, + linalg::WinogradFilterTransformOp op) { + Location loc = op.getLoc(); + Value filter = op.getFilter(); + auto filterType = cast(filter.getType()); + auto filterShape = filterType.getShape(); + int64_t filterH = filterShape[1]; + int64_t filterW = filterShape[2]; + + // For F(m x 1, r x 1), we only need to do left side transform. + bool leftTransform = filterH != 1; + // For F(1 x m, 1 x r), we only need to do right side transform. + bool rightTransform = filterW != 1; + Value transformedFilter = + filterTransform(rewriter, loc, filter, op.getOutput(), op.getM(), + op.getR(), leftTransform, rightTransform); + if (!transformedFilter) +return failure(); + + rewriter.replaceOp(op, transformedFilter); + + return transformedFilter.getDefiningOp(); +} + +FailureOr +decomposeWinogradInputTransformHelper(RewriterBase &rewriter, + linalg::WinogradInputTransformOp op) { + Location loc = op.getLoc(); + Value input = op.getInput(); + auto inputType = cast(input.getType()); + auto inputShape = inputType.getShape(); + int64_t inputH = inputShape[1]; + int64_t inputW = inputShape[2]; + + // For F(m x 1, r x 1), we only need to do left side transform. + bool leftTransform = inputH != 1; + // For F(1 x m, 1 x r), we only need to do right side transform. 
+ bool rightTransform = inputW != 1; + Value transformedInput = + inputTransform(rewriter, loc, op.getInput(), op.getOutput(), op.getM(), + op.getR(), leftTransform, rightTransform); + if (!transformedInput) +return failure(); + + rewriter.replaceOp(op, transformedInput); + + return transformedInput.getDefiningOp(); +} + +FailureOr +decomposeWinogradOutputTransformHelper(RewriterBase &rewriter, + linalg::WinogradOutputTransformOp op) { + Location loc = op.getLoc(); + Value value = op.getValue(); + auto valueType = cast(value.getType()); + auto valueShape = valueType.getShape(); + int64_t valueH = valueShape[2]; + int64_t valueW = valueShape[3]; + + // For F(m x 1, r x 1), we only need to do left side transform. + bool leftTransform = valueH != 1; + // For F(1 x m, 1 x r), we only need to do right side transform. + bool rightTransform = valueW != 1; + Value transformedOutput = + outputTransform(rewriter, loc, value, op.getOutput(), op.getM(), + op.getR(), leftTransform, rightTransform); + if (!transformedOutput) +return failure(); + + rewriter.replaceOp(op, transformedOutput); + + return transformedOutput.getDefiningOp(); +} + +class DecomposeWinogradFilterTransform final +: public OpRewritePattern { +public: + using OpRewritePattern::OpRewritePattern; + + LogicalResult matchAndRewrite(linalg::WinogradFilterTransformOp op, +PatternRewriter &rewriter) const override { +if (failed(decomposeWinogradFilterTransformHelper(rewriter, op))) + return failure(); + +return success(); + } +}; + +class DecomposeWinogradInputTransform final +: public OpRewritePattern { +public: + using OpRewritePattern::OpRewritePattern; + + LogicalResult matchAndRewrite(linalg::WinogradInputTransformOp op, +PatternRewriter &rewriter) const override { +if (failed(decomposeWinogradInputTransformHelper(rewriter, op))) + return failure(); + +return success(); ftynse wrote: ```suggestion return decomposeWinogradInputTransformHelper(rewriter, op); ``` https://github.com/llvm/llvm-project/pull/96183 
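Taken together, the three decomposed ops compute Y = AT[(G g GT) ⊙ (BT d B)]A per tile. The `leftTransform`/`rightTransform` flags drop the left or right multiplications when the H or W extent is 1, which gives the F(m x 1, r x 1) and F(1 x m, 1 x r) cases. The sketch below checks that identity against direct correlation in plain C++. It assumes the textbook F(2, 3) matrices, which are not necessarily the patch's tables, and derives the flags from the filter shape. All names are illustrative.

```cpp
#include <algorithm>
#include <cassert>
#include <cmath>
#include <cstddef>
#include <vector>

using Matrix = std::vector<std::vector<double>>;

Matrix matmul(const Matrix &a, const Matrix &b) {
  assert(!a.empty() && !b.empty() && a[0].size() == b.size());
  Matrix c(a.size(), std::vector<double>(b[0].size(), 0.0));
  for (std::size_t i = 0; i < a.size(); ++i)
    for (std::size_t k = 0; k < b.size(); ++k)
      for (std::size_t j = 0; j < b[0].size(); ++j)
        c[i][j] += a[i][k] * b[k][j];
  return c;
}

Matrix transpose(const Matrix &a) {
  Matrix t(a[0].size(), std::vector<double>(a.size()));
  for (std::size_t i = 0; i < a.size(); ++i)
    for (std::size_t j = 0; j < a[0].size(); ++j)
      t[j][i] = a[i][j];
  return t;
}

// Applies t on the left and/or t^T on the right, mirroring the
// leftTransform / rightTransform flags of the quoted helpers.
Matrix sandwich(const Matrix &t, const Matrix &x, bool left, bool right) {
  Matrix y = left ? matmul(t, x) : x;
  return right ? matmul(y, transpose(t)) : y;
}

// Assumption: textbook F(2, 3) matrices; the patch's tables may differ in sign.
const Matrix kBT = {{1, 0, -1, 0}, {0, 1, 1, 0}, {0, -1, 1, 0}, {0, 1, 0, -1}};
const Matrix kG = {{1, 0, 0}, {0.5, 0.5, 0.5}, {0.5, -0.5, 0.5}, {0, 0, 1}};
const Matrix kAT = {{1, 1, 1, 0}, {0, 1, -1, -1}};

// One F(2, 3) tile; a filter height (width) of 1 disables the left (right)
// transform.
Matrix winogradTileF23(const Matrix &d, const Matrix &g) {
  bool left = g.size() != 1;
  bool right = g[0].size() != 1;
  assert(d.size() == (left ? 4u : 1u) && d[0].size() == (right ? 4u : 1u));
  Matrix u = sandwich(kG, g, left, right);   // filter transform
  Matrix v = sandwich(kBT, d, left, right);  // input transform
  for (std::size_t i = 0; i < u.size(); ++i)
    for (std::size_t j = 0; j < u[0].size(); ++j)
      u[i][j] *= v[i][j];                    // element-wise product
  return sandwich(kAT, u, left, right);      // output transform
}

// Reference: direct 2-D correlation (no kernel flip), no padding.
Matrix directCorrelation(const Matrix &d, const Matrix &g) {
  std::size_t oh = d.size() - g.size() + 1;
  std::size_t ow = d[0].size() - g[0].size() + 1;
  Matrix y(oh, std::vector<double>(ow, 0.0));
  for (std::size_t i = 0; i < oh; ++i)
    for (std::size_t j = 0; j < ow; ++j)
      for (std::size_t a = 0; a < g.size(); ++a)
        for (std::size_t b = 0; b < g[0].size(); ++b)
          y[i][j] += d[i + a][j + b] * g[a][b];
  return y;
}

double maxAbsDiff(const Matrix &a, const Matrix &b) {
  assert(a.size() == b.size());
  double diff = 0.0;
  for (std::size_t i = 0; i < a.size(); ++i) {
    assert(a[i].size() == b[i].size());
    for (std::size_t j = 0; j < a[i].size(); ++j)
      diff = std::max(diff, std::fabs(a[i][j] - b[i][j]));
  }
  return diff;
}
```

The reference is correlation rather than flipped convolution, because correlation is what these F(2, 3) matrices compute.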
[llvm-branch-commits] [mlir] [mlir][linalg] Decompose winograd operators (PR #96183)
@@ -36,6 +189,92 @@ constexpr TransformMapKeyTy F_2_3{2, 3}; constexpr TransformMapKeyTy F_4_3{4, 3}; constexpr TransformMapKeyTy F_2_5{2, 5}; +struct TransformMatrix { + TransformMatrix(const float *table, int64_t rows, int64_t cols, + int64_t scalarFactor = 1) + : table(table), rows(rows), cols(cols), scalarFactor(scalarFactor) {} + + const float *table; + int64_t rows; + int64_t cols; + int64_t scalarFactor; +}; + +Value create2DTransformMatrix(RewriterBase &rewriter, Location loc, + TransformMatrix transform, Type type) { + ArrayRef const_vec(transform.table, transform.rows * transform.cols); ftynse wrote: Nit: camelBack ```suggestion ArrayRef constVec(transform.table, transform.rows * transform.cols); ``` https://github.com/llvm/llvm-project/pull/96183
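`create2DTransformMatrix` views `transform.table` as `rows * cols` contiguous floats. Assuming that layout is row-major (an assumption, since the diff only shows the element count), element (i, j) sits at `table[i * cols + j]`. The small sketch below keeps the quoted struct's fields; `at` and `to2D` are hypothetical helpers, and the test table holds textbook F(2, 3) AT values rather than the patch's.

```cpp
#include <cassert>
#include <cstddef>
#include <cstdint>
#include <vector>

// Same fields as the quoted TransformMatrix; `at` is a hypothetical helper.
struct TransformMatrix {
  TransformMatrix(const float *table, int64_t rows, int64_t cols,
                  int64_t scalarFactor = 1)
      : table(table), rows(rows), cols(cols), scalarFactor(scalarFactor) {}

  // Element (i, j), assuming a flat row-major table of rows * cols floats.
  float at(int64_t i, int64_t j) const {
    assert(i >= 0 && i < rows && j >= 0 && j < cols);
    return table[i * cols + j];
  }

  const float *table;
  int64_t rows;
  int64_t cols;
  int64_t scalarFactor;
};

// Hypothetical helper: expands the flat table into the rows x cols shape
// that a 2-D constant built from it would have.
std::vector<std::vector<float>> to2D(const TransformMatrix &t) {
  std::vector<std::vector<float>> m(
      static_cast<std::size_t>(t.rows),
      std::vector<float>(static_cast<std::size_t>(t.cols)));
  for (int64_t i = 0; i < t.rows; ++i)
    for (int64_t j = 0; j < t.cols; ++j)
      m[i][j] = t.at(i, j);
  return m;
}
```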
[llvm-branch-commits] [mlir] [mlir][linalg] Decompose winograd operators (PR #96183)
@@ -48,6 +287,261 @@ Value collapse2DData(RewriterBase &rewriter, Location loc, Value data) { reassociation); } +// This function transforms the filter. The data layout of the filter is FHWC. +// The transformation matrix is 2-dimension. We need to extract H x W from +// FHWC first. We need to generate 2 levels of loops to iterate on F and C. +// After the transformation, we get +// +// scf.for %f = lo_f to hi_f step 1 +// scf.for %c = lo_c to hi_c step 1 +// %extracted = extract filter from filter +// %ret = linalg.matmul G, %extracted +// %ret = linalg.matmul %ret, GT +// %inserted = insert %ret into filter +// +Value filterTransform(RewriterBase &rewriter, Location loc, Value filter, + Value retValue, int64_t m, int64_t r, + bool leftTransform = true, bool rightTransform = true) { + // Map from (m, r) to G transform matrix. + static const llvm::SmallDenseMap + GMatrices = { + {F_2_3, TransformMatrix(G_2x2_3x3, 4, 3)}, + {F_4_3, TransformMatrix(G_4x4_3x3, 6, 3)}, + {F_2_5, TransformMatrix(G_2x2_5x5, 6, 5)}, + }; + + // Map from (m, r) to GT transform matrix. 
+ static const llvm::SmallDenseMap + GTMatrices = { + {F_2_3, TransformMatrix(GT_2x2_3x3, 3, 4)}, + {F_4_3, TransformMatrix(GT_4x4_3x3, 3, 6)}, + {F_2_5, TransformMatrix(GT_2x2_5x5, 5, 6)}, + }; + + auto filterType = cast(filter.getType()); + Type elementType = filterType.getElementType(); + auto filterShape = filterType.getShape(); // F, H, W, C + int64_t filterF = filterShape[0]; + int64_t filterH = filterShape[1]; + int64_t filterW = filterShape[2]; + int64_t filterC = filterShape[3]; + + if (filterH != r && filterH != 1) +return Value(); + if (filterW != r && filterW != 1) +return Value(); + + // Return shape is + auto zeroIdx = rewriter.create(loc, 0); + auto fUpperBound = rewriter.create(loc, filterF); + auto cUpperBound = rewriter.create(loc, filterC); + auto oneStep = rewriter.create(loc, 1); + auto outerForOp = + rewriter.create(loc, zeroIdx, fUpperBound, oneStep, retValue); + Block *outerForBody = outerForOp.getBody(); + rewriter.setInsertionPointToStart(outerForBody); + Value FIter = outerForBody->getArgument(0); ftynse wrote: There must be a function on `scf::ForOp` that returns the induction variable and avoids magic constant zero here. https://github.com/llvm/llvm-project/pull/96183
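Both `filterTransform` and `outputTransform` return an empty `Value` unless each spatial extent is either the transformed size or 1. Below is a hypothetical restatement of those checks as standalone predicates, which makes the shape logic easy to unit-test in isolation. The function names are illustrative, not from the patch.

```cpp
#include <cassert>
#include <cstdint>

// Filter transform: each of H and W must be r (transformed) or 1 (skipped),
// matching `filterH != r && filterH != 1` in the quoted code.
bool isValidFilterDim(int64_t dim, int64_t r) {
  assert(r > 0);
  return dim == r || dim == 1;
}

// Output transform: the expected extent is alpha = m + r - 1 when the
// transform is enabled and 1 otherwise; extent 1 is always accepted,
// matching `valueH != alphaH && valueH != 1`.
bool isValidOutputTileDim(int64_t dim, int64_t m, int64_t r, bool transformed) {
  assert(m > 0 && r > 0);
  int64_t expected = transformed ? m + r - 1 : 1;
  return dim == expected || dim == 1;
}
```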
[llvm-branch-commits] [mlir] [mlir][linalg] Implement TilingInterface for winograd operators (PR #96184)
https://github.com/ftynse commented: I think @MaheshRavishankar should take a look at the interface implementation details. https://github.com/llvm/llvm-project/pull/96184
[llvm-branch-commits] [mlir] [mlir][linalg] Implement TilingInterface for winograd operators (PR #96184)
https://github.com/ftynse edited https://github.com/llvm/llvm-project/pull/96184
[llvm-branch-commits] [mlir] [mlir][linalg] Implement TilingInterface for winograd operators (PR #96184)
@@ -2760,6 +2760,89 @@ LogicalResult WinogradFilterTransformOp::verify() { return success(); } +SmallVector +WinogradFilterTransformOp::getIterationDomain(OpBuilder &builder) { + Location loc = getLoc(); + Value zero = builder.create(loc, 0); + Value one = builder.create(loc, 1); + Value output = getOutput(); + SmallVector loopBounds(6); + for (unsigned dim = 0; dim < 6; ++dim) { +loopBounds[dim].offset = zero; +loopBounds[dim].size = getDimValue(builder, loc, output, dim); +loopBounds[dim].stride = one; + } + return loopBounds; +} + +SmallVector +WinogradFilterTransformOp::getLoopIteratorTypes() { + SmallVector iteratorTypes(6, + utils::IteratorType::parallel); + return iteratorTypes; +} + +Value getValueFromOpFoldResult(OpFoldResult opFoldResult, OpBuilder &builder, + Location loc) { + if (auto val = opFoldResult.dyn_cast()) { +return val; + } else if (auto attr = opFoldResult.dyn_cast()) { +auto intAttr = cast(attr); +return builder.create(loc, intAttr); + } ftynse wrote: I suspect this might already exist somewhere in the arith dialect. https://github.com/llvm/llvm-project/pull/96184
[llvm-branch-commits] [mlir] [mlir][linalg] Implement TilingInterface for winograd operators (PR #96184)
@@ -2638,4 +2638,41 @@ def WinogradConv2DOp : Op { + let description = [{ +Decompose winograd operators. It will convert filter, input and output +transform operators into a combination of scf, tensor, and linalg ftynse wrote: Nit: operations https://github.com/llvm/llvm-project/pull/96184
[llvm-branch-commits] [mlir] [mlir][linalg] Implement TilingInterface for winograd operators (PR #96184)
@@ -2760,6 +2760,89 @@ LogicalResult WinogradFilterTransformOp::verify() { return success(); } +SmallVector +WinogradFilterTransformOp::getIterationDomain(OpBuilder &builder) { + Location loc = getLoc(); + Value zero = builder.create(loc, 0); + Value one = builder.create(loc, 1); + Value output = getOutput(); + SmallVector loopBounds(6); + for (unsigned dim = 0; dim < 6; ++dim) { +loopBounds[dim].offset = zero; +loopBounds[dim].size = getDimValue(builder, loc, output, dim); +loopBounds[dim].stride = one; + } + return loopBounds; +} + +SmallVector +WinogradFilterTransformOp::getLoopIteratorTypes() { + SmallVector iteratorTypes(6, + utils::IteratorType::parallel); + return iteratorTypes; +} + +Value getValueFromOpFoldResult(OpFoldResult opFoldResult, OpBuilder &builder, + Location loc) { + if (auto val = opFoldResult.dyn_cast()) { +return val; + } else if (auto attr = opFoldResult.dyn_cast()) { +auto intAttr = cast(attr); +return builder.create(loc, intAttr); + } + // This should never happen if OpFoldResult is correctly formed. ftynse wrote: Then this should be an assertion. https://github.com/llvm/llvm-project/pull/96184
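The two review points on `getValueFromOpFoldResult` are to reuse an existing helper and to assert rather than silently fall through. On the first point, MLIR's arith utilities provide `getValueOrCreateConstantIndexOp`, which looks like the helper the reviewer suspects exists. It materializes an `index` constant, so it would need checking against the attribute types used here. The second point can be illustrated without MLIR. The plain-C++ sketch below uses a hypothetical `std::variant` stand-in for `OpFoldResult`, with the impossible case written as an assertion.

```cpp
#include <cassert>
#include <cstdint>
#include <string>
#include <variant>
#include <vector>

// Hypothetical stand-ins: an SSA value identified by name, and a fold result
// that is either such a value or a constant integer (like an attribute).
struct SsaValue {
  std::string name;
};
using FoldResult = std::variant<SsaValue, int64_t>;

// Returns the value held by `ofr`, or "materializes" its constant by
// recording a new constant name in `createdConstants`. A well-formed
// FoldResult always takes one of the two paths, so anything else is asserted
// rather than handled.
SsaValue getOrMaterialize(const FoldResult &ofr,
                          std::vector<std::string> &createdConstants) {
  if (const auto *value = std::get_if<SsaValue>(&ofr))
    return *value;
  const auto *constant = std::get_if<int64_t>(&ofr);
  assert(constant && "FoldResult must hold a value or a constant");
  std::string name = "%c" + std::to_string(*constant);
  createdConstants.push_back(name);
  return SsaValue{name};
}
```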
[llvm-branch-commits] [mlir] [mlir][linalg] Implement TilingInterface for winograd operators (PR #96184)
@@ -2760,6 +2760,89 @@ LogicalResult WinogradFilterTransformOp::verify() { return success(); } +SmallVector +WinogradFilterTransformOp::getIterationDomain(OpBuilder &builder) { + Location loc = getLoc(); + Value zero = builder.create(loc, 0); + Value one = builder.create(loc, 1); ftynse wrote: IIRC, `Range` contains a list of `OpFoldResult`, meaning we can put attributes there and not materialize operations for these constants. https://github.com/llvm/llvm-project/pull/96184
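The idea behind this comment is that when a range's bounds can be either SSA values or attributes, constant offsets and strides can stay as attributes and no constant ops need to be created. The later `WinogradInputTransformOp::getIterationDomain` revision in this thread does exactly that with `zeroAttr`/`oneAttr`. Below is a hypothetical plain-C++ analog; its `Range` and `FoldResult` types are stand-ins, not MLIR's.

```cpp
#include <cassert>
#include <cstdint>
#include <string>
#include <variant>
#include <vector>

// Hypothetical stand-ins for Value- or Attribute-backed fold results.
struct SsaValue {
  std::string name;
};
using FoldResult = std::variant<SsaValue, int64_t>;

// A loop range whose fields may hold constants directly.
struct Range {
  FoldResult offset;
  FoldResult size;
  FoldResult stride;
};

// Builds one parallel range per output dimension. Offset 0 and stride 1 stay
// plain constants (nothing is materialized); only the sizes come from the
// possibly dynamic output dimensions.
std::vector<Range> iterationDomain(const std::vector<FoldResult> &outputDims) {
  std::vector<Range> domain;
  domain.reserve(outputDims.size());
  for (const FoldResult &dim : outputDims)
    domain.push_back(Range{int64_t{0}, dim, int64_t{1}});
  return domain;
}
```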
[llvm-branch-commits] [mlir] [mlir][Transforms][NFC] Dialect Conversion: Move argument materialization logic (PR #96329)
https://github.com/ftynse approved this pull request. https://github.com/llvm/llvm-project/pull/96329
[llvm-branch-commits] [mlir] [mlir][linalg] Implement TilingInterface for winograd operators (PR #96184)
@@ -2810,9 +2819,117 @@ LogicalResult WinogradInputTransformOp::verify() { if (failed(verifyCompatibleShape(expectedOutputShape, outputShape))) { return emitOpError("the output shape is not expected"); } + return success(); } +SmallVector +WinogradInputTransformOp::getIterationDomain(OpBuilder &builder) { + Location loc = getLoc(); + auto indexType = builder.getIndexType(); + auto zeroAttr = builder.getIntegerAttr(indexType, 0); + auto oneAttr = builder.getIntegerAttr(indexType, 1); + Value output = getOutput(); + SmallVector loopBounds(6); + for (unsigned dim = 0; dim < 6; ++dim) { +loopBounds[dim].offset = zeroAttr; +loopBounds[dim].size = getDimValue(builder, loc, output, dim); +loopBounds[dim].stride = oneAttr; + } + return loopBounds; +} + +SmallVector +WinogradInputTransformOp::getLoopIteratorTypes() { + SmallVector iteratorTypes(6, + utils::IteratorType::parallel); + return iteratorTypes; +} + +LogicalResult WinogradInputTransformOp::getResultTilePosition( +OpBuilder &builder, unsigned resultNumber, ArrayRef offsets, +ArrayRef sizes, SmallVector &resultOffsets, +SmallVector &resultSizes) { + auto zeroAttr = builder.getI64IntegerAttr(0); + auto oneAttr = builder.getI64IntegerAttr(1); + + resultOffsets.push_back(zeroAttr); + resultOffsets.push_back(zeroAttr); + resultOffsets.push_back(offsets[2]); + resultOffsets.push_back(offsets[3]); + resultOffsets.push_back(zeroAttr); + resultOffsets.push_back(zeroAttr); + resultSizes.push_back(sizes[0]); + resultSizes.push_back(sizes[1]); + resultSizes.push_back(oneAttr); + resultSizes.push_back(oneAttr); + resultSizes.push_back(sizes[4]); + resultSizes.push_back(sizes[5]); + + return success(); +} + +FailureOr +WinogradInputTransformOp::getTiledImplementation(OpBuilder &builder, + ArrayRef offsets, + ArrayRef sizes) { + auto oneAttr = builder.getI64IntegerAttr(1); + auto zeroAttr = builder.getI64IntegerAttr(0); + Value input = getInput(); + auto inputType = cast(input.getType()); + auto inputShape = inputType.getShape(); 
ftynse wrote: Please expand `auto` unless the type is obvious from statement-level context (builders on the RHS are fine, but I don't remember what `getShape` returns as a type). https://github.com/llvm/llvm-project/pull/96184
[llvm-branch-commits] [mlir] [mlir][linalg] Implement TilingInterface for winograd operators (PR #96184)
@@ -2810,9 +2819,117 @@ LogicalResult WinogradInputTransformOp::verify() { if (failed(verifyCompatibleShape(expectedOutputShape, outputShape))) { return emitOpError("the output shape is not expected"); } + return success(); } +SmallVector +WinogradInputTransformOp::getIterationDomain(OpBuilder &builder) { + Location loc = getLoc(); + auto indexType = builder.getIndexType(); + auto zeroAttr = builder.getIntegerAttr(indexType, 0); + auto oneAttr = builder.getIntegerAttr(indexType, 1); + Value output = getOutput(); + SmallVector loopBounds(6); + for (unsigned dim = 0; dim < 6; ++dim) { +loopBounds[dim].offset = zeroAttr; +loopBounds[dim].size = getDimValue(builder, loc, output, dim); +loopBounds[dim].stride = oneAttr; + } + return loopBounds; +} + +SmallVector +WinogradInputTransformOp::getLoopIteratorTypes() { + SmallVector iteratorTypes(6, + utils::IteratorType::parallel); + return iteratorTypes; +} + +LogicalResult WinogradInputTransformOp::getResultTilePosition( +OpBuilder &builder, unsigned resultNumber, ArrayRef offsets, +ArrayRef sizes, SmallVector &resultOffsets, +SmallVector &resultSizes) { + auto zeroAttr = builder.getI64IntegerAttr(0); + auto oneAttr = builder.getI64IntegerAttr(1); + + resultOffsets.push_back(zeroAttr); + resultOffsets.push_back(zeroAttr); + resultOffsets.push_back(offsets[2]); + resultOffsets.push_back(offsets[3]); + resultOffsets.push_back(zeroAttr); + resultOffsets.push_back(zeroAttr); + resultSizes.push_back(sizes[0]); + resultSizes.push_back(sizes[1]); + resultSizes.push_back(oneAttr); + resultSizes.push_back(oneAttr); + resultSizes.push_back(sizes[4]); + resultSizes.push_back(sizes[5]); ftynse wrote: Nit: something like `resultOffsets.append({zeroAttr, zeroAttr, offsets[2], offsets[3], zeroAttr, zeroAttr})` may be more readable. https://github.com/llvm/llvm-project/pull/96184
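The suggested `append` with a braced list is the `SmallVector` counterpart of `std::vector::insert(end, {...})`. Below is a sketch of the same offset/size mapping, with plain integers standing in for `OpFoldResult` (a hypothetical simplification). It mirrors the quoted code as written: dims 2 and 3 follow the tile offsets with size 1, while dims 0, 1, 4 and 5 use offset 0 and take the tile sizes.

```cpp
#include <cassert>
#include <cstdint>
#include <vector>

// Plain-integer version of the quoted getResultTilePosition mapping.
void getResultTilePosition(const std::vector<int64_t> &offsets,
                           const std::vector<int64_t> &sizes,
                           std::vector<int64_t> &resultOffsets,
                           std::vector<int64_t> &resultSizes) {
  assert(offsets.size() == 6 && sizes.size() == 6);
  // One braced-list append per vector instead of twelve push_back calls.
  resultOffsets.insert(resultOffsets.end(),
                       {0, 0, offsets[2], offsets[3], 0, 0});
  resultSizes.insert(resultSizes.end(),
                     {sizes[0], sizes[1], 1, 1, sizes[4], sizes[5]});
}
```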
[llvm-branch-commits] [mlir] [mlir][linalg] Implement TilingInterface for winograd operators (PR #96184)
@@ -2810,9 +2819,117 @@ LogicalResult WinogradInputTransformOp::verify() { if (failed(verifyCompatibleShape(expectedOutputShape, outputShape))) { return emitOpError("the output shape is not expected"); } + return success(); } +SmallVector +WinogradInputTransformOp::getIterationDomain(OpBuilder &builder) { + Location loc = getLoc(); + auto indexType = builder.getIndexType(); + auto zeroAttr = builder.getIntegerAttr(indexType, 0); + auto oneAttr = builder.getIntegerAttr(indexType, 1); + Value output = getOutput(); + SmallVector loopBounds(6); + for (unsigned dim = 0; dim < 6; ++dim) { +loopBounds[dim].offset = zeroAttr; +loopBounds[dim].size = getDimValue(builder, loc, output, dim); +loopBounds[dim].stride = oneAttr; + } + return loopBounds; +} + +SmallVector +WinogradInputTransformOp::getLoopIteratorTypes() { + SmallVector iteratorTypes(6, + utils::IteratorType::parallel); + return iteratorTypes; +} + +LogicalResult WinogradInputTransformOp::getResultTilePosition( +OpBuilder &builder, unsigned resultNumber, ArrayRef offsets, +ArrayRef sizes, SmallVector &resultOffsets, +SmallVector &resultSizes) { + auto zeroAttr = builder.getI64IntegerAttr(0); + auto oneAttr = builder.getI64IntegerAttr(1); + + resultOffsets.push_back(zeroAttr); + resultOffsets.push_back(zeroAttr); + resultOffsets.push_back(offsets[2]); + resultOffsets.push_back(offsets[3]); + resultOffsets.push_back(zeroAttr); + resultOffsets.push_back(zeroAttr); + resultSizes.push_back(sizes[0]); + resultSizes.push_back(sizes[1]); + resultSizes.push_back(oneAttr); + resultSizes.push_back(oneAttr); + resultSizes.push_back(sizes[4]); + resultSizes.push_back(sizes[5]); + + return success(); +} + +FailureOr +WinogradInputTransformOp::getTiledImplementation(OpBuilder &builder, + ArrayRef offsets, + ArrayRef sizes) { + auto oneAttr = builder.getI64IntegerAttr(1); + auto zeroAttr = builder.getI64IntegerAttr(0); + Value input = getInput(); + auto inputType = cast(input.getType()); + auto inputShape = inputType.getShape(); 
+ int64_t inputH = inputShape[1]; + int64_t inputW = inputShape[2]; + int64_t m = getM(); + int64_t r = getR(); + int64_t alpha = m + r - 1; + int64_t alphaH = inputH != 1 ? alpha : 1; + int64_t alphaW = inputW != 1 ? alpha : 1; + auto alphaHAttr = builder.getI64IntegerAttr(alphaH); + auto alphaWAttr = builder.getI64IntegerAttr(alphaW); + + Location loc = getLoc(); + SmallVector tiledOperands; + SmallVector sliceOffsets, sliceSizes; + + auto context = builder.getContext(); + auto affineMap = + AffineMap::get(1, 0, {builder.getAffineDimExpr(0) * m}, context); + Value mappedOffset1 = builder.create( + loc, affineMap, getValueFromOpFoldResult(offsets[2], builder, loc)); + Value mappedOffset2 = builder.create( + loc, affineMap, getValueFromOpFoldResult(offsets[3], builder, loc)); + + sliceOffsets.push_back(zeroAttr); + sliceOffsets.push_back(mappedOffset1); + sliceOffsets.push_back(mappedOffset2); + sliceOffsets.push_back(zeroAttr); + sliceSizes.push_back(sizes[4]); + sliceSizes.push_back(alphaHAttr); + sliceSizes.push_back(alphaWAttr); + sliceSizes.push_back(sizes[5]); + SmallVector inputStrides(4, oneAttr); + tiledOperands.emplace_back(builder.create( + loc, getInput(), sliceOffsets, sliceSizes, inputStrides)); + + sliceOffsets.clear(); + sliceSizes.clear(); ftynse wrote: I'd rather declare new vectors for this. https://github.com/llvm/llvm-project/pull/96184
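The affine map `d0 * m` encodes the Winograd tiling rule: output tile t starts at input row t * m and reads alpha = m + r - 1 rows, so neighboring tiles overlap by r - 1. Below is a plain-C++ sketch of the input slice for one tile. It uses integers instead of `OpFoldResult` and builds fresh vectors for the operand, as the reviewer suggests. `Slice` and `inputSliceForTile` are hypothetical names.

```cpp
#include <cassert>
#include <cstdint>
#include <vector>

// Offsets and sizes of the (N, H, W, C) input slice read by one tile.
struct Slice {
  std::vector<int64_t> offsets;
  std::vector<int64_t> sizes;
};

// tileIdxH/tileIdxW are tile indices; sizeN/sizeC are the tiled N and C
// extents. A spatial extent of 1 is not transformed and keeps size 1.
Slice inputSliceForTile(int64_t tileIdxH, int64_t tileIdxW, int64_t sizeN,
                        int64_t sizeC, int64_t inputH, int64_t inputW,
                        int64_t m, int64_t r) {
  assert(m > 0 && r > 0);
  int64_t alpha = m + r - 1;
  int64_t alphaH = inputH != 1 ? alpha : 1;
  int64_t alphaW = inputW != 1 ? alpha : 1;
  // A fresh Slice per operand instead of clearing and reusing shared vectors.
  Slice slice;
  slice.offsets = {0, tileIdxH * m, tileIdxW * m, 0};
  slice.sizes = {sizeN, alphaH, alphaW, sizeC};
  return slice;
}
```

For F(4 x 4, 3 x 3), tile (2, 3) reads a 6 x 6 window starting at (8, 12).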
[llvm-branch-commits] [mlir] [mlir][linalg] Implement TilingInterface for winograd operators (PR #96184)
@@ -2810,9 +2819,117 @@ LogicalResult WinogradInputTransformOp::verify() {
   if (failed(verifyCompatibleShape(expectedOutputShape, outputShape))) {
     return emitOpError("the output shape is not expected");
   }
+
   return success();
 }
 
+SmallVector
+WinogradInputTransformOp::getIterationDomain(OpBuilder &builder) {
+  Location loc = getLoc();
+  auto indexType = builder.getIndexType();
+  auto zeroAttr = builder.getIntegerAttr(indexType, 0);
+  auto oneAttr = builder.getIntegerAttr(indexType, 1);
+  Value output = getOutput();
+  SmallVector loopBounds(6);
+  for (unsigned dim = 0; dim < 6; ++dim) {
+    loopBounds[dim].offset = zeroAttr;
+    loopBounds[dim].size = getDimValue(builder, loc, output, dim);
+    loopBounds[dim].stride = oneAttr;
+  }
+  return loopBounds;
+}
+
+SmallVector
+WinogradInputTransformOp::getLoopIteratorTypes() {
+  SmallVector iteratorTypes(6,
+                            utils::IteratorType::parallel);
+  return iteratorTypes;
+}
+
+LogicalResult WinogradInputTransformOp::getResultTilePosition(
+    OpBuilder &builder, unsigned resultNumber, ArrayRef offsets,
+    ArrayRef sizes, SmallVector &resultOffsets,
+    SmallVector &resultSizes) {
+  auto zeroAttr = builder.getI64IntegerAttr(0);
+  auto oneAttr = builder.getI64IntegerAttr(1);
+
+  resultOffsets.push_back(zeroAttr);
+  resultOffsets.push_back(zeroAttr);
+  resultOffsets.push_back(offsets[2]);
+  resultOffsets.push_back(offsets[3]);
+  resultOffsets.push_back(zeroAttr);
+  resultOffsets.push_back(zeroAttr);
+  resultSizes.push_back(sizes[0]);
+  resultSizes.push_back(sizes[1]);
+  resultSizes.push_back(oneAttr);
+  resultSizes.push_back(oneAttr);
+  resultSizes.push_back(sizes[4]);
+  resultSizes.push_back(sizes[5]);
+
+  return success();
+}
+
+FailureOr
+WinogradInputTransformOp::getTiledImplementation(OpBuilder &builder,
+                                                 ArrayRef offsets,
+                                                 ArrayRef sizes) {
+  auto oneAttr = builder.getI64IntegerAttr(1);
+  auto zeroAttr = builder.getI64IntegerAttr(0);
+  Value input = getInput();
+  auto inputType = cast(input.getType());
+  auto inputShape = inputType.getShape();
+  int64_t inputH = inputShape[1];
+  int64_t inputW = inputShape[2];
+  int64_t m = getM();
+  int64_t r = getR();
+  int64_t alpha = m + r - 1;
+  int64_t alphaH = inputH != 1 ? alpha : 1;
+  int64_t alphaW = inputW != 1 ? alpha : 1;
+  auto alphaHAttr = builder.getI64IntegerAttr(alphaH);
+  auto alphaWAttr = builder.getI64IntegerAttr(alphaW);
+
+  Location loc = getLoc();
+  SmallVector tiledOperands;
+  SmallVector sliceOffsets, sliceSizes;
+
+  auto context = builder.getContext();
+  auto affineMap =
+      AffineMap::get(1, 0, {builder.getAffineDimExpr(0) * m}, context);
+  Value mappedOffset1 = builder.create(
+      loc, affineMap, getValueFromOpFoldResult(offsets[2], builder, loc));
+  Value mappedOffset2 = builder.create(
+      loc, affineMap, getValueFromOpFoldResult(offsets[3], builder, loc));
+
+  sliceOffsets.push_back(zeroAttr);
+  sliceOffsets.push_back(mappedOffset1);
+  sliceOffsets.push_back(mappedOffset2);
+  sliceOffsets.push_back(zeroAttr);
+  sliceSizes.push_back(sizes[4]);
+  sliceSizes.push_back(alphaHAttr);
+  sliceSizes.push_back(alphaWAttr);
+  sliceSizes.push_back(sizes[5]);
+  SmallVector inputStrides(4, oneAttr);
+  tiledOperands.emplace_back(builder.create(
+      loc, getInput(), sliceOffsets, sliceSizes, inputStrides));
+
+  sliceOffsets.clear();
+  sliceSizes.clear();
+  if (failed(getResultTilePosition(builder, 1, offsets, sizes, sliceOffsets,
+                                   sliceSizes)))
+    return failure();
+
+  SmallVector outputStrides(6, oneAttr);
+  tiledOperands.emplace_back(builder.create(
+      loc, getOutput(), sliceOffsets, sliceSizes, outputStrides));
+
+  SmallVector resultTypes;

ftynse wrote:
```suggestion
  SmallVector resultTypes;
```

https://github.com/llvm/llvm-project/pull/96184
___
llvm-branch-commits mailing list
llvm-branch-commits@lists.llvm.org
https://lists.llvm.org/cgi-bin/mailman/listinfo/llvm-branch-commits
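The index arithmetic in `getResultTilePosition` and `getTiledImplementation` above can be checked without building MLIR. The following plain C++17 sketch mirrors it with integers instead of `OpFoldResult`s. `Slice`, `TileVec`, `resultTile`, and `inputSlice` are hypothetical names, and the layouts `(N, H, W, C)` for the input and `(alphaH, alphaW, tileH, tileW, N, C)` for the output are inferred from the indexing in the hunk rather than stated in it. For Winograd F(m, r), each output tile of m elements along an axis needs an input window of alpha = m + r - 1 elements, and consecutive windows start m elements apart.

```cpp
#include <array>
#include <cassert>
#include <cstddef>
#include <cstdint>

// Plain-integer stand-in for the offset/size OpFoldResult vectors in the hunk.
template <std::size_t Rank>
struct Slice {
  std::array<int64_t, Rank> offsets;
  std::array<int64_t, Rank> sizes;
};

// Offsets or sizes over the 6-D iteration domain (one entry per output dim).
using TileVec = std::array<int64_t, 6>;

// Mirrors getResultTilePosition: output assumed to be
// (alphaH, alphaW, tileH, tileW, N, C); each tiled instance produces exactly
// one (tileH, tileW) tile.
Slice<6> resultTile(const TileVec &offsets, const TileVec &sizes) {
  return {{0, 0, offsets[2], offsets[3], 0, 0},
          {sizes[0], sizes[1], 1, 1, sizes[4], sizes[5]}};
}

// Mirrors the input slice built in getTiledImplementation: input assumed to
// be (N, H, W, C).
Slice<4> inputSlice(int64_t m, int64_t r, int64_t inputH, int64_t inputW,
                    const TileVec &offsets, const TileVec &sizes) {
  assert(m >= 1 && r >= 1);
  // F(m, r): m outputs along an axis need an input window of m + r - 1.
  const int64_t alpha = m + r - 1;
  // A unit spatial dimension (1-D convolution) keeps a window of 1.
  const int64_t alphaH = inputH != 1 ? alpha : 1;
  const int64_t alphaW = inputW != 1 ? alpha : 1;
  // Tile t starts at element t * m (the affine map d0 * m in the hunk), so
  // neighbouring windows overlap by r - 1 elements.
  return {{0, offsets[2] * m, offsets[3] * m, 0},
          {sizes[4], alphaH, alphaW, sizes[5]}};
}
```

For F(2, 3), alpha is 4, so tile (1, 2) reads a 4 × 4 input window starting at row 2, column 4. The sketch copies the hunk as written, including the zero offsets it uses for the N and C dimensions.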
[llvm-branch-commits] [mlir] [mlir][linalg] Implement TilingInterface for winograd operators (PR #96184)
https://github.com/ftynse commented:
Looks okay to me in general. Something went wrong with rebases so I see code that doesn't belong to this change. Let me know when you merged the bases and I can click-approve this one.

https://github.com/llvm/llvm-project/pull/96184
[llvm-branch-commits] [mlir] [mlir][linalg] Implement TilingInterface for winograd operators (PR #96184)
@@ -2776,6 +2776,15 @@ LogicalResult WinogradFilterTransformOp::verify() {
 // WinogradInputTransformOp
 //===--===//
 
+Value getValueFromOpFoldResult(OpFoldResult opFoldResult, OpBuilder &builder,
+                               Location loc) {
+  if (auto attr = opFoldResult.dyn_cast()) {
+    auto intAttr = cast(attr);
+    return builder.create(loc, intAttr);
+  }
+  return opFoldResult.get();
+}

ftynse wrote:
https://github.com/llvm/llvm-project/blob/main/mlir/include/mlir/Dialect/Arith/Utils/Utils.h#L68-L72 it already exists

https://github.com/llvm/llvm-project/pull/96184
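The reviewer's point is that `getValueFromOpFoldResult` duplicates an existing helper. The linked lines of `mlir/Dialect/Arith/Utils/Utils.h` presumably declare `getValueOrCreateConstantIndexOp(OpBuilder &, Location, OpFoldResult)`; treat that name as an assumption to check against the link. Both perform the same small dispatch on an `OpFoldResult`, which holds either a constant attribute or an SSA `Value`. The toy model below (plain C++17, no MLIR; `Value`, `IntAttr`, `FoldResult`, `ToyBuilder`, and `materialize` are hypothetical stand-ins) shows only that dispatch: a constant is materialized as a new constant op, while a dynamic value is passed through without creating anything.

```cpp
#include <cassert>
#include <cstdint>
#include <string>
#include <variant>
#include <vector>

// Hypothetical stand-ins: an SSA value is identified by its name, and an
// integer attribute is a known constant.
struct Value {
  std::string name;
};
struct IntAttr {
  int64_t value;
};
// Like OpFoldResult: either a static constant or a dynamic value.
using FoldResult = std::variant<IntAttr, Value>;

// Toy builder that records every constant it materializes.
struct ToyBuilder {
  std::vector<int64_t> constants;
  Value createConstant(int64_t v) {
    constants.push_back(v);
    return Value{"%c" + std::to_string(v)};
  }
};

// The dispatch both helpers perform: a constant becomes a new constant op,
// a dynamic value is returned unchanged and nothing is created.
Value materialize(ToyBuilder &b, const FoldResult &ofr) {
  if (const IntAttr *attr = std::get_if<IntAttr>(&ofr))
    return b.createConstant(attr->value);
  return std::get<Value>(ofr);
}
```

One difference worth checking when switching: the hand-written helper builds the constant with the attribute's own type, whereas an `...IndexOp` helper would be expected to produce `index` constants, which is the operand type `affine.apply` requires.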
[llvm-branch-commits] [mlir] [mlir][linalg] Implement TilingInterface for winograd operators (PR #96184)
@@ -2810,9 +2819,117 @@ LogicalResult WinogradInputTransformOp::verify() {
[...]
+FailureOr
+WinogradInputTransformOp::getTiledImplementation(OpBuilder &builder,
+                                                 ArrayRef offsets,
+                                                 ArrayRef sizes) {
+  auto oneAttr = builder.getI64IntegerAttr(1);
+  auto zeroAttr = builder.getI64IntegerAttr(0);
+  Value input = getInput();
+  auto inputType = cast(input.getType());
+  auto inputShape = inputType.getShape();

ftynse wrote:
Here and below.

https://github.com/llvm/llvm-project/pull/96184
[llvm-branch-commits] [mlir] [mlir][Transform] `apply_conversion_patterns`: Update handles (PR #83950)
@@ -632,7 +663,11 @@ LogicalResult transform::ApplyConversionPatternsOp::verify() {
 void transform::ApplyConversionPatternsOp::getEffects(
     SmallVectorImpl &effects) {
-  transform::consumesHandle(getTarget(), effects);
+  if (!getPreserveHandles()) {
+    transform::consumesHandle(getTarget(), effects);
+  } else {
+    transform::onlyReadsHandle(getTarget(), effects);
+  }

ftynse wrote:
Nit: I don't recall if dialect conversion could rewrite the top-level op or not. If it can, it may need to still consume the handle...

https://github.com/llvm/llvm-project/pull/83950
[llvm-branch-commits] [mlir] [mlir][Transform] `apply_conversion_patterns`: Update handles (PR #83950)
https://github.com/ftynse edited https://github.com/llvm/llvm-project/pull/83950
[llvm-branch-commits] [mlir] [mlir][Transform] `apply_conversion_patterns`: Update handles (PR #83950)
https://github.com/ftynse approved this pull request. https://github.com/llvm/llvm-project/pull/83950
[llvm-branch-commits] [mlir] [mlir][Transform] `apply_conversion_patterns`: Update handles (PR #83950)
@@ -190,19 +190,29 @@ def ApplyConversionPatternsOp : TransformDialectOp<"apply_conversion_patterns",
     The `legal_ops`, `illegal_ops`, `legal_dialects`, `illegal_dialects`
     attributes specify the conversion target.
 
-    This transform consumes the `target` handle and modifies the payload. It
-    does not produce any handles.
+    This transform modifies the payload. By default, it consumes the `target`
+    handle. It does not produce any handles.
+
+    If the `preserve_handles` attribute is set, this transform does not consume
+    the `target` handle and instead updates handles based on notifications from
+    a tracking listener that is attached to the dialect conversion, similar to
+    `transform.apply_patterns`. Only replacements via `RewriterBase::replaceOp`
+    or `replaceOpWithNewOp` are considered "payload op replacements". In
+    contrast to `transform.apply_patterns`, we allow replacement ops even if the

ftynse wrote:
Nit: could you provide the rationale as to why op name change is allowed here? Presumably because conversion is expected to intentionally change names to another dialect.

https://github.com/llvm/llvm-project/pull/83950
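The handle-update rule in the new description can be illustrated with a toy model. The sketch below is plain C++17 and is not MLIR's tracking listener: the class, its `sameNameOnly` switch, and the op names are hypothetical, and it simplifies a `replaceOp`-style notification to one op replacing another. Setting `sameNameOnly` models the stricter `transform.apply_patterns` behaviour the description contrasts against; clearing it models `preserve_handles`, where a conversion such as `arith.addi` to `llvm.add` still updates the handle. That matches the reviewer's presumed rationale: a conversion is expected to change op names because it moves ops into another dialect.

```cpp
#include <cassert>
#include <map>
#include <string>

// Hypothetical payload op: only its name matters here.
struct Op {
  std::string name;
};

// Toy model of handle tracking: each handle maps to the payload op it
// currently refers to.
class ToyTrackingListener {
public:
  explicit ToyTrackingListener(bool sameNameOnly)
      : requireSameName(sameNameOnly) {}

  void track(const std::string &handle, const Op *op) { handles[handle] = op; }

  // Handles pointing at oldOp are redirected to newOp. Returns false if the
  // name policy rejects newOp; in that case the handle is left pointing at
  // the replaced op (a real implementation would report an error instead).
  bool notifyOpReplaced(const Op *oldOp, const Op *newOp) {
    bool ok = true;
    for (auto &entry : handles) {
      if (entry.second != oldOp)
        continue;
      if (requireSameName && newOp->name != oldOp->name) {
        ok = false;
        continue;
      }
      entry.second = newOp;
    }
    return ok;
  }

  const Op *lookup(const std::string &handle) const {
    return handles.at(handle);
  }

private:
  bool requireSameName;
  std::map<std::string, const Op *> handles;
};
```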
[llvm-branch-commits] [mlir] release/18.x: [MLIR] [Transforms] Let `transform.structured.convert_to_loops` return handles to loops (#83984) (PR #85942)
ftynse wrote:
https://github.com/llvm/llvm-project/commit/0597644a6466ae9148b0b41cb8f95d5022e045c2 looks like a bugfix, but https://github.com/llvm/llvm-project/commit/47bc565ca7990a2de20af4030baf08ac62739aca is arguably a new feature and likely should not be backported. What is the reason for backporting the latter?

https://github.com/llvm/llvm-project/pull/85942
[llvm-branch-commits] [mlir] release/18.x: [mlir][transform] replace original op to loop ops (#83537) (PR #87080)
https://github.com/ftynse approved this pull request. https://github.com/llvm/llvm-project/pull/87080