Author: MaheshRavishankar
Date: 2021-01-15T13:55:35-08:00
New Revision: d7bc3b7ce23b664d6620cdc32370a8614523ca2f

URL: 
https://github.com/llvm/llvm-project/commit/d7bc3b7ce23b664d6620cdc32370a8614523ca2f
DIFF: 
https://github.com/llvm/llvm-project/commit/d7bc3b7ce23b664d6620cdc32370a8614523ca2f.diff

LOG: [mlir][Linalg] Add missing check to canonicalization of GenericOps that are identity ops.

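The operation is an identity only if every value yielded by the operation
is an argument of that operation's own body block (not a block argument of
some enclosing region), and the number of yielded values matches the number
of results of the operation. Add these missing checks.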

Differential Revision: https://reviews.llvm.org/D94819

Added: 
    

Modified: 
    mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
    mlir/test/Dialect/Linalg/canonicalize.mlir

Removed: 
    


################################################################################
diff  --git a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
index 30a6b9c0c371..fa98ed0cfbc9 100644
--- a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
+++ b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
@@ -2276,13 +2276,15 @@ struct RemoveIdentityLinalgOps : public RewritePattern {
     SmallVector<Value, 4> returnedArgs;
     for (Value yieldVal : yieldOp.values()) {
       auto yieldArg = yieldVal.dyn_cast<BlockArgument>();
-      if (!yieldArg)
+      if (!yieldArg || yieldArg.getOwner() != &body)
         return failure();
       unsigned argumentNumber = yieldArg.getArgNumber();
       if (argumentNumber < numIndexArgs)
         return failure();
       returnedArgs.push_back(op->getOperand(argumentNumber - numIndexArgs));
     }
+    if (returnedArgs.size() != genericOp.getOperation()->getNumResults())
+      return failure();
     rewriter.replaceOp(genericOp, returnedArgs);
     return success();
   }

diff  --git a/mlir/test/Dialect/Linalg/canonicalize.mlir b/mlir/test/Dialect/Linalg/canonicalize.mlir
index ca7f82c1b254..cc00b98d376c 100644
--- a/mlir/test/Dialect/Linalg/canonicalize.mlir
+++ b/mlir/test/Dialect/Linalg/canonicalize.mlir
@@ -615,3 +615,56 @@ func @remove_no_op(%arg0 : tensor<?x?x?xf32>, %arg1 : tensor<?x?x?xf32>)
 //  CHECK-SAME:   %[[ARG0:[a-zA-Z0-9_]+]]: tensor<?x?x?xf32>
 //  CHECK-SAME:   %[[ARG1:[a-zA-Z0-9_]+]]: tensor<?x?x?xf32>
 //       CHECK:     return %[[ARG1]], %[[ARG0]]
+
+// -----
+
+#map = affine_map<(d0, d1) -> (d0, d1)>
+func @keep_not_noop(%arg0 : tensor<?x?xf32>) -> tensor<?x?xf32> {
+  %c0 = constant 0 : index
+  %c1 = constant 1 : index
+  %cst = constant 1.000000e+00 : f32
+  %0 = dim %arg0, %c0 : tensor<?x?xf32>
+  %1 = dim %arg0, %c1 : tensor<?x?xf32>
+  %2 = linalg.init_tensor [%0, %1] : tensor<?x?xf32>
+  br ^bb1(%cst : f32)
+
+^bb1(%arg1 : f32):
+  %3 = linalg.generic
+    {indexing_maps = [#map, #map], iterator_types = ["parallel", "parallel"]}
+    ins(%arg0 : tensor<?x?xf32>) outs(%2 : tensor<?x?xf32>) {
+    ^bb0(%arg2: f32, %arg3 : f32):
+      linalg.yield %arg1 : f32
+    } -> tensor<?x?xf32>
+  return %3 : tensor<?x?xf32>
+}
+// CHECK-LABEL: func @keep_not_noop
+//       CHECK:   %[[RESULT:.+]] = linalg.generic
+//       CHECK:   return %[[RESULT]]
+
+// -----
+
+#map = affine_map<(d0, d1) -> (d0, d1)>
+func @keep_not_noop(%arg0 : tensor<?x?xf32>, %arg1 : tensor<?x?xf32>)
+  -> (tensor<?x?xf32>, tensor<?x?xf32>) {
+  %c0 = constant 0 : index
+  %c1 = constant 1 : index
+  %cst = constant 1.000000e+00 : f32
+  %0 = dim %arg0, %c0 : tensor<?x?xf32>
+  %1 = dim %arg0, %c1 : tensor<?x?xf32>
+  %2 = linalg.init_tensor [%0, %1] : tensor<?x?xf32>
+  br ^bb1(%cst : f32)
+
+^bb1(%arg2 : f32):
+  %3:2 = linalg.generic
+    {indexing_maps = [#map, #map, #map, #map],
+     iterator_types = ["parallel", "parallel"]}
+    ins(%arg0, %arg1 : tensor<?x?xf32>, tensor<?x?xf32>)
+    outs(%2, %2 : tensor<?x?xf32>, tensor<?x?xf32>) {
+    ^bb0(%arg3: f32, %arg4 : f32, %arg5 : f32, %arg6 : f32):
+      linalg.yield %arg2, %arg4 : f32, f32
+    } -> tensor<?x?xf32>, tensor<?x?xf32>
+  return %3#0, %3#1 : tensor<?x?xf32>, tensor<?x?xf32>
+}
+// CHECK-LABEL: func @keep_not_noop
+//       CHECK:   %[[RESULT:.+]]:2 = linalg.generic
+//       CHECK:   return %[[RESULT]]#0, %[[RESULT]]#1


        
_______________________________________________
llvm-branch-commits mailing list
llvm-branch-commits@lists.llvm.org
https://lists.llvm.org/cgi-bin/mailman/listinfo/llvm-branch-commits

Reply via email to