@@ -2776,6 +2776,15 @@ LogicalResult WinogradFilterTransformOp::verify() {
// WinogradInputTransformOp
//===--===//
+Value getValueFromOpFoldResult(OpFoldResult opFoldResult, OpBuilder &builder,
+
@@ -2810,9 +2819,117 @@ LogicalResult WinogradInputTransformOp::verify() {
if (failed(verifyCompatibleShape(expectedOutputShape, outputShape))) {
return emitOpError("the output shape is not expected");
}
+
return success();
}
+SmallVector
+WinogradInputTransformOp
@@ -2810,9 +2819,117 @@ LogicalResult WinogradInputTransformOp::verify() {
if (failed(verifyCompatibleShape(expectedOutputShape, outputShape))) {
return emitOpError("the output shape is not expected");
}
+
return success();
}
+SmallVector
+WinogradInputTransformOp
https://github.com/ftynse commented:
Looks okay to me in general. Something went wrong with rebases so I see code
that doesn't belong to this change. Let me know when you merged the bases and I
can click-approve this one.
https://github.com/llvm/llvm-project/pull/96184
@@ -2810,9 +2819,117 @@ LogicalResult WinogradInputTransformOp::verify() {
if (failed(verifyCompatibleShape(expectedOutputShape, outputShape))) {
return emitOpError("the output shape is not expected");
}
+
return success();
}
+SmallVector
+WinogradInputTransformOp
https://github.com/ftynse edited https://github.com/llvm/llvm-project/pull/96184
___
llvm-branch-commits mailing list
llvm-branch-commits@lists.llvm.org
https://lists.llvm.org/cgi-bin/mailman/listinfo/llvm-branch-commits
@@ -2810,9 +2819,117 @@ LogicalResult WinogradInputTransformOp::verify() {
if (failed(verifyCompatibleShape(expectedOutputShape, outputShape))) {
return emitOpError("the output shape is not expected");
}
+
return success();
}
+SmallVector
+WinogradInputTransformOp
@@ -2810,9 +2819,117 @@ LogicalResult WinogradInputTransformOp::verify() {
if (failed(verifyCompatibleShape(expectedOutputShape, outputShape))) {
return emitOpError("the output shape is not expected");
}
+
return success();
}
+SmallVector
+WinogradInputTransformOp
https://github.com/ftynse approved this pull request.
https://github.com/llvm/llvm-project/pull/96329
___
llvm-branch-commits mailing list
llvm-branch-commits@lists.llvm.org
https://lists.llvm.org/cgi-bin/mailman/listinfo/llvm-branch-commits
@@ -2638,4 +2638,41 @@ def WinogradConv2DOp : Op {
+ let description = [{
+Decompose winograd operators. It will convert filter, input and output
+transform operators into a combination of scf, tensor, and linalg
ftynse wrote:
Nit: operations
https://gi
@@ -2760,6 +2760,89 @@ LogicalResult WinogradFilterTransformOp::verify() {
return success();
}
+SmallVector
+WinogradFilterTransformOp::getIterationDomain(OpBuilder &builder) {
+ Location loc = getLoc();
+ Value zero = builder.create(loc, 0);
+ Value one = builder.create(
@@ -2760,6 +2760,89 @@ LogicalResult WinogradFilterTransformOp::verify() {
return success();
}
+SmallVector
+WinogradFilterTransformOp::getIterationDomain(OpBuilder &builder) {
+ Location loc = getLoc();
+ Value zero = builder.create(loc, 0);
+ Value one = builder.create(
@@ -2760,6 +2760,89 @@ LogicalResult WinogradFilterTransformOp::verify() {
return success();
}
+SmallVector
+WinogradFilterTransformOp::getIterationDomain(OpBuilder &builder) {
+ Location loc = getLoc();
+ Value zero = builder.create(loc, 0);
+ Value one = builder.create(
https://github.com/ftynse commented:
I think @MaheshRavishankar should take a look at the interface implementation
details.
https://github.com/llvm/llvm-project/pull/96184
___
llvm-branch-commits mailing list
llvm-branch-commits@lists.llvm.org
https:/
https://github.com/ftynse edited https://github.com/llvm/llvm-project/pull/96184
___
llvm-branch-commits mailing list
llvm-branch-commits@lists.llvm.org
https://lists.llvm.org/cgi-bin/mailman/listinfo/llvm-branch-commits
@@ -36,6 +189,92 @@ constexpr TransformMapKeyTy F_2_3{2, 3};
constexpr TransformMapKeyTy F_4_3{4, 3};
constexpr TransformMapKeyTy F_2_5{2, 5};
+struct TransformMatrix {
+ TransformMatrix(const float *table, int64_t rows, int64_t cols,
+ int64_t scalarFactor =
@@ -100,6 +594,161 @@ Value matrixMultiply(RewriterBase &rewriter, Location loc,
return expandOutput;
}
+// This function transforms the output. The data layout of the output is HWNF.
+// The transformation matrix is 2-dimension. We need to extract H x W from
+// HWNF first.
@@ -48,6 +287,261 @@ Value collapse2DData(RewriterBase &rewriter, Location loc,
Value data) {
reassociation);
}
+// This function transforms the filter. The data layout of the filter is FHWC.
+// The transformation matrix is 2
@@ -289,6 +938,123 @@ FailureOr winogradConv2DHelper(RewriterBase
&rewriter,
return transformedOutput.getDefiningOp();
}
+FailureOr
+decomposeWinogradFilterTransformHelper(RewriterBase &rewriter,
+ linalg::WinogradFilterTransformOp op)
@@ -48,6 +287,261 @@ Value collapse2DData(RewriterBase &rewriter, Location loc,
Value data) {
reassociation);
}
+// This function transforms the filter. The data layout of the filter is FHWC.
+// The transformation matrix is 2
@@ -48,6 +287,261 @@ Value collapse2DData(RewriterBase &rewriter, Location loc,
Value data) {
reassociation);
}
+// This function transforms the filter. The data layout of the filter is FHWC.
+// The transformation matrix is 2
@@ -323,5 +1089,12 @@ void populateWinogradConv2DPatterns(RewritePatternSet
&patterns, int64_t m,
patterns.insert(context, m, r);
}
+void populateDecomposeWinogradOpsPatterns(RewritePatternSet &patterns) {
+ MLIRContext *context = patterns.getContext();
+ patterns.insert(
@@ -100,6 +594,161 @@ Value matrixMultiply(RewriterBase &rewriter, Location loc,
return expandOutput;
}
+// This function transforms the output. The data layout of the output is HWNF.
+// The transformation matrix is 2-dimension. We need to extract H x W from
+// HWNF first.
@@ -100,6 +594,161 @@ Value matrixMultiply(RewriterBase &rewriter, Location loc,
return expandOutput;
}
+// This function transforms the output. The data layout of the output is HWNF.
+// The transformation matrix is 2-dimension. We need to extract H x W from
+// HWNF first.
@@ -48,6 +287,261 @@ Value collapse2DData(RewriterBase &rewriter, Location loc,
Value data) {
reassociation);
}
+// This function transforms the filter. The data layout of the filter is FHWC.
+// The transformation matrix is 2
@@ -36,6 +189,92 @@ constexpr TransformMapKeyTy F_2_3{2, 3};
constexpr TransformMapKeyTy F_4_3{4, 3};
constexpr TransformMapKeyTy F_2_5{2, 5};
+struct TransformMatrix {
ftynse wrote:
Please document top-level entities.
https://github.com/llvm/llvm-project/pu
@@ -23,6 +26,156 @@ namespace linalg {
namespace {
+// clang-format off
+// Winograd Conv2D uses a minimal 2D filtering algorithm to calculate its
+// result. The formula of minimal 2D filtering algorithm F(m x m, r x r),
+// m is the output dimension and r is the filter dime
@@ -3480,6 +3480,31 @@ DiagnosedSilenceableFailure
transform::MapCopyToThreadsOp::applyToOne(
return DiagnosedSilenceableFailure::success();
}
+//===--===//
+// WinogradConv2DOp
+//===--
@@ -2587,4 +2587,55 @@ def MapCopyToThreadsOp :
}];
}
+//===--===//
+// Winograd Conv2D
+//===--===//
+
+def WinogradConv2DOp : Op {
+ let
@@ -0,0 +1,88 @@
+// RUN: mlir-opt %s -transform-interpreter -canonicalize --split-input-file |
FileCheck %s
+
+func.func @conv2d(%arg0: tensor<2x10x10x5xf32>, %arg1: tensor<2x3x3x5xf32>,
%arg2: tensor<1xf32>) -> tensor<2x8x8x2xf32> {
+ %0 = tensor.empty() : tensor<2x8x8x2xf32>
@@ -0,0 +1,88 @@
+// RUN: mlir-opt %s -transform-interpreter -canonicalize --split-input-file |
FileCheck %s
+
+func.func @conv2d(%arg0: tensor<2x10x10x5xf32>, %arg1: tensor<2x3x3x5xf32>,
%arg2: tensor<1xf32>) -> tensor<2x8x8x2xf32> {
+ %0 = tensor.empty() : tensor<2x8x8x2xf32>
@@ -3480,6 +3480,31 @@ DiagnosedSilenceableFailure
transform::MapCopyToThreadsOp::applyToOne(
return DiagnosedSilenceableFailure::success();
}
+//===--===//
+// WinogradConv2DOp
+//===--
@@ -2587,4 +2587,55 @@ def MapCopyToThreadsOp :
}];
}
+//===--===//
+// Winograd Conv2D
+//===--===//
+
+def WinogradConv2DOp : Op {
+ let
@@ -321,15 +323,15 @@ class RandomizedWorklist : public Worklist {
/// to the worklist in the beginning.
class GreedyPatternRewriteDriver : public RewriterBase::Listener {
protected:
- explicit GreedyPatternRewriteDriver(MLIRContext *ctx,
+ explicit GreedyPatternRewriteDriver
@@ -1053,3 +1055,241 @@ LogicalResult mlir::applyOpPatternsAndFold(
});
return converged;
}
+
+//===--===//
+// One-Shot Dialect Conversion Infrastructure
+//===-
@@ -1053,3 +1055,241 @@ LogicalResult mlir::applyOpPatternsAndFold(
});
return converged;
}
+
+//===--===//
+// One-Shot Dialect Conversion Infrastructure
+//===-
@@ -1819,6 +1822,22 @@ detail::ConversionPatternRewriterImpl
&ConversionPatternRewriter::getImpl() {
return *impl;
}
+void ConversionPatternRewriter::setCurrentTypeConverter(
+const TypeConverter *converter) {
+ impl->currentTypeConverter = converter;
+}
+
+const TypeC
https://github.com/ftynse approved this pull request.
https://github.com/llvm/llvm-project/pull/87080
___
llvm-branch-commits mailing list
llvm-branch-commits@lists.llvm.org
https://lists.llvm.org/cgi-bin/mailman/listinfo/llvm-branch-commits
ftynse wrote:
https://github.com/llvm/llvm-project/commit/0597644a6466ae9148b0b41cb8f95d5022e045c2
looks like a bugfix, but
https://github.com/llvm/llvm-project/commit/47bc565ca7990a2de20af4030baf08ac62739aca
is a arguably a new feature and likely should not be backported. What is the
reason
@@ -190,19 +190,29 @@ def ApplyConversionPatternsOp :
TransformDialectOp<"apply_conversion_patterns",
The `legal_ops`, `illegal_ops`, `legal_dialects`, `illegal_dialects`
attributes specify the conversion target.
-This transform consumes the `target` handle and mo
https://github.com/ftynse approved this pull request.
https://github.com/llvm/llvm-project/pull/83950
___
llvm-branch-commits mailing list
llvm-branch-commits@lists.llvm.org
https://lists.llvm.org/cgi-bin/mailman/listinfo/llvm-branch-commits
https://github.com/ftynse edited https://github.com/llvm/llvm-project/pull/83950
___
llvm-branch-commits mailing list
llvm-branch-commits@lists.llvm.org
https://lists.llvm.org/cgi-bin/mailman/listinfo/llvm-branch-commits
@@ -632,7 +663,11 @@ LogicalResult
transform::ApplyConversionPatternsOp::verify() {
void transform::ApplyConversionPatternsOp::getEffects(
SmallVectorImpl &effects) {
- transform::consumesHandle(getTarget(), effects);
+ if (!getPreserveHandles()) {
+transform::consu
43 matches
Mail list logo