Author: MaheshRavishankar Date: 2021-01-14T14:59:24-08:00 New Revision: 722ae10907e06a0bafa00c557e5242b53419a3ce
URL: https://github.com/llvm/llvm-project/commit/722ae10907e06a0bafa00c557e5242b53419a3ce DIFF: https://github.com/llvm/llvm-project/commit/722ae10907e06a0bafa00c557e5242b53419a3ce.diff LOG: [mlir][Linalg] Add canonicalization to remove no-op linalg operations. linalg.generic/indexed_generic operations on tensors whose body is just yielding the (non-induction variable) arguments of the operation can be canonicalized by replacing uses of the result with the corresponding arguments. Differential Revision: https://reviews.llvm.org/D94581 Added: Modified: mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp mlir/test/Dialect/Linalg/canonicalize.mlir Removed: ################################################################################ diff --git a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp index 8732065bb042..b74e44d91176 100644 --- a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp +++ b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp @@ -2119,6 +2119,54 @@ struct DeduplicateInputs : public RewritePattern { } }; +/// Remove generic/indexed_generic operations (on tensors) that are just copying +/// values from their operands to their results. Requirements checked below: +/// 1) The op has tensor semantics and every indexing map is an identity map. +/// NOTE(review): iterator types are not inspected; there is no explicit +/// "all iterator types are parallel" check -- confirm that identity +/// indexing maps alone are sufficient. +/// 2) The body contains just a yield operation whose yielded values are +/// block arguments corresponding to operands (not the induction +/// variables of an indexed_generic). Each result is replaced by the +/// operand that its yielded block argument corresponds to. +/// NOTE(review): the replacement operand's type is not compared with the +/// result type. If an operand type can differ from the result type (e.g. +/// static vs. dynamic shape), replaceOp would produce a type mismatch -- +/// confirm against the op verifier; a cast may be required. 
+struct RemoveIdentityLinalgOps : public RewritePattern { + RemoveIdentityLinalgOps(PatternBenefit benefit = 1) + : RewritePattern(benefit, MatchAnyOpTypeTag()) {} + + LogicalResult matchAndRewrite(Operation *op, + PatternRewriter &rewriter) const override { + if (!isa<GenericOp, IndexedGenericOp>(op)) + return failure(); + LinalgOp genericOp = cast<LinalgOp>(op); + if (!genericOp.hasTensorSemantics()) + return failure(); + // Check all indexing maps are identity. 
+ if (llvm::any_of(genericOp.getIndexingMaps(), + [](AffineMap map) { return !map.isIdentity(); })) + return failure(); + + // Check that the body of the linalg operation is just a linalg.yield + // operation (the terminator must be the block's only operation). + Block &body = op->getRegion(0).front(); + if (!llvm::hasSingleElement(body)) + return failure(); + auto yieldOp = dyn_cast<linalg::YieldOp>(body.getTerminator()); + if (!yieldOp) + return failure(); + + // Map each yielded value back to an operand. The block arguments begin + // with `numIndexArgs` induction variables (the payload induction + // variables of an indexed_generic; presumably zero for generic), followed + // by one argument per operand, so the operand index is + // argumentNumber - numIndexArgs. The collected operands replace the + // op's results one-to-one, in yield order. + unsigned numIndexArgs = genericOp.getNumPayloadInductionVariables(); + SmallVector<Value, 4> returnedArgs; + for (Value yieldVal : yieldOp.values()) { + // Only block arguments can be forwarded; a computed value means the + // op is not a pure copy. + auto yieldArg = yieldVal.dyn_cast<BlockArgument>(); + if (!yieldArg) + return failure(); + unsigned argumentNumber = yieldArg.getArgNumber(); + // Yielding an induction variable does not forward any operand. + if (argumentNumber < numIndexArgs) + return failure(); + returnedArgs.push_back(op->getOperand(argumentNumber - numIndexArgs)); + } + rewriter.replaceOp(genericOp, returnedArgs); + return success(); + } +}; + /// Canonicalize a `linalgOp` -> `dim` pattern by replacing the `dim` arg /// with the corresponding output tensor argument of the linalg op. 
struct ReplaceDimOfLinalgResult : public OpRewritePattern<DimOp> { @@ -2143,7 +2191,8 @@ struct ReplaceDimOfLinalgResult : public OpRewritePattern<DimOp> { #define CANONICALIZERS_AND_FOLDERS(XXX) \ void XXX::getCanonicalizationPatterns(OwningRewritePatternList &results, \ MLIRContext *context) { \ - results.insert<DeduplicateInputs, EraseDeadLinalgOp, FoldTensorCastOp>(); \ + results.insert<DeduplicateInputs, EraseDeadLinalgOp, FoldTensorCastOp, \ + RemoveIdentityLinalgOps>(); \ results.insert<ReplaceDimOfLinalgResult>(context); \ } \ \ diff --git a/mlir/test/Dialect/Linalg/canonicalize.mlir b/mlir/test/Dialect/Linalg/canonicalize.mlir index 6b806c801341..b2de3fdc6c8e 100644 --- a/mlir/test/Dialect/Linalg/canonicalize.mlir +++ b/mlir/test/Dialect/Linalg/canonicalize.mlir @@ -249,8 +249,10 @@ func @dce_zero_memref(%arg0 : memref<0xf32>, %arg1: tensor<0xf32>) -> tensor<0xf return %1: tensor<0xf32> } // CHECK-LABEL: @dce_zero_memref +// CHECK-SAME: %[[ARG0:[a-zA-Z0-9_]+]]: memref<0xf32> +// CHECK-SAME: %[[ARG1:[a-zA-Z0-9_]+]]: tensor<0xf32> // CHECK-NOT: linalg.copy -// CHECK-NEXT: linalg.generic +// CHECK-NEXT: return %[[ARG1]] // ----- @@ -449,3 +451,30 @@ func @init_tensor_reshape_collapse(%arg0 : index) -> tensor<6x5x?xf32> { // CHECK: %[[T0:.+]] = muli %[[ARG0]], %[[C28]] // CHECK: %[[T1:.+]] = linalg.init_tensor [6, 5, %[[T0]]] // CHECK: return %[[T1]] + +// ----- + +// Each result of the linalg.generic below only forwards an ins block +// argument: %4 yields %arg3 (the block argument for %arg1) and %5 yields +// %arg2 (the block argument for %arg0), so canonicalization should replace +// both results with the function arguments in swapped order. +// NOTE(review): this hunk adds no negative case (non-identity indexing map, +// computed yield value, or yielded indexed_generic induction variable). 
+#map = affine_map<(d0, d1, d2) -> (d0, d1, d2)> +func @remove_no_op(%arg0 : tensor<?x?x?xf32>, %arg1 : tensor<?x?x?xf32>) + -> (tensor<?x?x?xf32>, tensor<?x?x?xf32>) { + %c0 = constant 0 : index + %c1 = constant 1 : index + %c2 = constant 2 : index + %0 = dim %arg0, %c0 : tensor<?x?x?xf32> + %1 = dim %arg0, %c1 : tensor<?x?x?xf32> + %2 = dim %arg0, %c2 : tensor<?x?x?xf32> + %3 = linalg.init_tensor [%0, %1, %2] : tensor<?x?x?xf32> + %4, %5 = linalg.generic { + indexing_maps = [#map, #map, #map, #map], + iterator_types = ["parallel", "parallel", "parallel"] + } ins(%arg0, %arg1 : tensor<?x?x?xf32>, 
tensor<?x?x?xf32>) + outs(%3, %3 : tensor<?x?x?xf32>, tensor<?x?x?xf32>) { + ^bb0(%arg2 : f32, %arg3 : f32, %arg4 : f32, %arg5 : f32): + linalg.yield %arg3, %arg2 : f32, f32 + } -> tensor<?x?x?xf32>, tensor<?x?x?xf32> + return %4, %5 : tensor<?x?x?xf32>, tensor<?x?x?xf32> +} +// CHECK-LABEL: func @remove_no_op +// CHECK-SAME: %[[ARG0:[a-zA-Z0-9_]+]]: tensor<?x?x?xf32> +// CHECK-SAME: %[[ARG1:[a-zA-Z0-9_]+]]: tensor<?x?x?xf32> +// CHECK: return %[[ARG1]], %[[ARG0]] _______________________________________________ llvm-branch-commits mailing list llvm-branch-commits@lists.llvm.org https://lists.llvm.org/cgi-bin/mailman/listinfo/llvm-branch-commits