Author: Alexander Belyaev Date: 2021-01-05T15:15:21+01:00 New Revision: 89ae5b5b6a475addb7248ca7a948a944a15f0275
URL: https://github.com/llvm/llvm-project/commit/89ae5b5b6a475addb7248ca7a948a944a15f0275 DIFF: https://github.com/llvm/llvm-project/commit/89ae5b5b6a475addb7248ca7a948a944a15f0275.diff LOG: [mlir] Add canonicalization pattern out_tensor->linalg->dim to out_tensor->dim. Differential Revision: https://reviews.llvm.org/D94079 Added: Modified: mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp mlir/test/Dialect/Linalg/canonicalize.mlir Removed: ################################################################################ diff --git a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp index bcbd6d9036121..529ba35a0b87d 100644 --- a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp +++ b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp @@ -1958,14 +1958,33 @@ struct DeduplicateInputs : public RewritePattern { return success(); } }; + +/// Canonicalize a `linalgOp` -> `dim` pattern by replacing the `dim` arg +/// with the corresponding output tensor argument of the linalg op. +struct ReplaceDimOfLinalgResult : public OpRewritePattern<DimOp> { + using OpRewritePattern<DimOp>::OpRewritePattern; + + LogicalResult matchAndRewrite(DimOp dimOp, + PatternRewriter &rewriter) const override { + Value dimOpArg = dimOp.memrefOrTensor(); + auto linalgOp = dimOpArg.getDefiningOp<LinalgOp>(); + if (!linalgOp) + return failure(); + + auto results = linalgOp.getOperation()->getResults(); + int64_t id = std::distance(results.begin(), llvm::find(results, dimOpArg)); + auto outputTensors = linalgOp.getOutputTensors(); + rewriter.replaceOpWithNewOp<DimOp>(dimOp, outputTensors[id], dimOp.index()); + return success(); + } +}; } // namespace #define CANONICALIZERS_AND_FOLDERS(XXX) \ void XXX::getCanonicalizationPatterns(OwningRewritePatternList &results, \ MLIRContext *context) { \ - results.insert<EraseDeadLinalgOp>(); \ - results.insert<FoldTensorCastOp>(); \ - results.insert<DeduplicateInputs>(); \ + results.insert<DeduplicateInputs, EraseDeadLinalgOp, FoldTensorCastOp>(); \
+ results.insert<ReplaceDimOfLinalgResult>(context); \ } \ \ LogicalResult XXX::fold(ArrayRef<Attribute>, \ diff --git a/mlir/test/Dialect/Linalg/canonicalize.mlir b/mlir/test/Dialect/Linalg/canonicalize.mlir index f015d5fd64fd9..faac64c0d91a9 100644 --- a/mlir/test/Dialect/Linalg/canonicalize.mlir +++ b/mlir/test/Dialect/Linalg/canonicalize.mlir @@ -389,3 +389,31 @@ func @init_tensor_dynamic_dim(%arg0 : index) -> (index) { // CHECK: func @init_tensor_dynamic_dim // CHECK-SAME: %[[ARG0:[a-zA-Z0-9_]+]]: index // CHECK: return %[[ARG0]] + +// ----- + +#map = affine_map<(d0) -> (d0)> + +func @init_tensor_dim_of_linalg_result(%arg_0 : tensor<?xf32>, + %arg_1: tensor<?xf32>) -> (tensor<?xf32>, tensor<?xf32>) { + %0, %1 = linalg.generic { + indexing_maps = [#map, #map, #map], + iterator_types = ["parallel"] + } ins(%arg_0 : tensor<?xf32>) + outs(%arg_0, %arg_1 : tensor<?xf32>, tensor<?xf32>) { + ^bb0(%in: f32, %out_0: f32, %out_1: f32): + linalg.yield %in, %in : f32, f32 + } -> tensor<?xf32>, tensor<?xf32> + + %c0 = constant 0 : index + %num_elem_0 = dim %0, %c0 : tensor<?xf32> + %result_0 = linalg.init_tensor [%num_elem_0] : tensor<?xf32> + + %num_elem_1 = dim %1, %c0 : tensor<?xf32> + %result_1 = linalg.init_tensor [%num_elem_1] : tensor<?xf32> + return %result_0, %result_1 : tensor<?xf32>, tensor<?xf32> +} +// CHECK-LABEL: func @init_tensor_dim_of_linalg_result( +// CHECK-SAME: [[ARG_0:%.*]]: tensor<?xf32>, [[ARG_1:%.*]]: tensor<?xf32>) +// CHECK: dim [[ARG_0]] +// CHECK: dim [[ARG_1]] _______________________________________________ llvm-branch-commits mailing list llvm-branch-commits@lists.llvm.org https://lists.llvm.org/cgi-bin/mailman/listinfo/llvm-branch-commits