Author: Hanhan Wang
Date: 2020-12-03T23:11:29-08:00
New Revision: f5f1a5c2448e31f3c7e6f85b378372a02f8d3e43

URL: https://github.com/llvm/llvm-project/commit/f5f1a5c2448e31f3c7e6f85b378372a02f8d3e43
DIFF: https://github.com/llvm/llvm-project/commit/f5f1a5c2448e31f3c7e6f85b378372a02f8d3e43.diff

LOG: [mlir][Linalg] Handle fusion on tensors for projected permutation.

In the past, the reshape op could be folded only if the indexing map of the
fused tensor was a permutation in the consumer's usage. We can relax the
condition to a projected permutation.

This patch still does not support fusion for the scalar case. The scalar case
is a corner case, because we would need to decide where to put the extra dims.
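
For illustration, a minimal sketch of the distinction (the map names below are
only labels for this example, not identifiers from the patch):

  // A permutation: every loop dimension appears exactly once in the results.
  #permutation = affine_map<(d0, d1, d2) -> (d2, d0, d1)>

  // A projected permutation: a permutation with some dimensions dropped
  // (here d2 is projected away). With this change, such a map on the fused
  // tensor no longer blocks the fusion, provided it has at least one result
  // (i.e. the fused tensor is not a scalar).
  #projected = affine_map<(d0, d1, d2) -> (d0, d1)>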

Reviewed By: mravishankar

Differential Revision: https://reviews.llvm.org/D92466

Added: 
    

Modified: 
    mlir/include/mlir/Dialect/Linalg/Utils/Utils.h
    mlir/lib/Dialect/Linalg/Transforms/FusionOnTensors.cpp
    mlir/lib/Dialect/Linalg/Utils/Utils.cpp
    mlir/test/Dialect/Linalg/reshape_fusion.mlir

Removed: 
    


################################################################################
diff --git a/mlir/include/mlir/Dialect/Linalg/Utils/Utils.h b/mlir/include/mlir/Dialect/Linalg/Utils/Utils.h
index fb916d3962e3..3df609f295cc 100644
--- a/mlir/include/mlir/Dialect/Linalg/Utils/Utils.h
+++ b/mlir/include/mlir/Dialect/Linalg/Utils/Utils.h
@@ -118,11 +118,12 @@ Optional<SmallVector<Value, 1>> fuseTensorOps(PatternRewriter &rewriter,
 /// dimension is statically known, or -1 otherwise.
 SmallVector<int64_t, 8> getStaticShape(LinalgOp linalgOp);
 
-/// Returns the statically-known loop ranges of the `linalgOp`. Applies the
-/// inverse of the concatenated indexing maps to the result of `getStaticShape`.
-/// Returns None if inverting the concatenated indexing map fails. Returns -1
+/// Returns the statically-known loop ranges of the `linalgOp`. Composes
+/// `linalgOp.getShapesToLoopsMap()` with the result of `getStaticShape`.
+/// Returns None if `linalgOp.getShapesToLoopsMap()` fails. Returns -1
 /// for non-statically-known loop ranges.
 Optional<SmallVector<int64_t, 4>> getStaticLoopRanges(LinalgOp linalgOp);
+
 /// Apply the permutation defined by `permutation` to `inVec`.
 /// Element `i` in `inVec` is mapped to location `j = permutation[i]`.
 /// E.g.: for an input vector `inVec = ['a', 'b', 'c']` and a permutation vector

diff --git a/mlir/lib/Dialect/Linalg/Transforms/FusionOnTensors.cpp b/mlir/lib/Dialect/Linalg/Transforms/FusionOnTensors.cpp
index fea80fac76a5..22e03c1e2f92 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/FusionOnTensors.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/FusionOnTensors.cpp
@@ -411,21 +411,19 @@ static bool isFusableWithReshapeByDimExpansion(LinalgOp linalgOp,
                                                unsigned fusedTensorIndex) {
   // Is fusable only if:
   // - The linalgOp is a generic op, or an indexed_generic.
-  // - All the indexing maps for operands in linalgOp are projected
+  // - All the indexing maps for operands and results in linalgOp are projected
   //   permutations.
-  // - The indexing map at the position representing the fused tensor is a
-  //   permutation.
+  // - The fused tensor is not a scalar.
   // - All the loops in linalgOp are parallel loops.
   return isa<GenericOp, IndexedGenericOp>(linalgOp.getOperation()) &&
          linalgOp.hasTensorSemantics() &&
-         llvm::all_of(linalgOp.indexing_maps().getValue().take_front(
-                          linalgOp.getNumInputs()),
+         llvm::all_of(linalgOp.indexing_maps().getValue(),
                       [](Attribute attr) {
                         return attr.cast<AffineMapAttr>()
                             .getValue()
                             .isProjectedPermutation();
                       }) &&
-         linalgOp.getIndexingMap(fusedTensorIndex).isPermutation() &&
+         linalgOp.getIndexingMap(fusedTensorIndex).getNumResults() > 0 &&
          llvm::all_of(linalgOp.iterator_types(), [](Attribute attr) {
            return attr.cast<StringAttr>().getValue() ==
                   getParallelIteratorTypeName();
@@ -446,8 +444,6 @@ fuseWithReshapeByExpansion(LinalgOp linalgOp, TensorReshapeOp reshapeOp,
       reshapeOp.getSrcType().getRank() < reshapeOp.getResultType().getRank();
   RankedTensorType expandedType =
       isExpanding ? reshapeOp.getResultType() : reshapeOp.getSrcType();
-  RankedTensorType foldedType =
-      isExpanding ? reshapeOp.getSrcType() : reshapeOp.getResultType();
   AffineMap fusedIndexMap = linalgOp.getIndexingMap(fusedTensorIndex);
 
   // The reshape is folding/expanding consecutive dimensions. Given the indexing
@@ -455,9 +451,15 @@ fuseWithReshapeByExpansion(LinalgOp linalgOp, TensorReshapeOp reshapeOp,
   // the original op is expanded into. Also record the shape of the expanded
   // dimensions.
   ArrayRef<int64_t> expandedShape = expandedType.getShape();
-  SmallVector<unsigned, 4> numFoldedDims(foldedType.getRank(), 0);
+  Optional<SmallVector<int64_t, 4>> origOpLoopRange =
+      getStaticLoopRanges(linalgOp);
+  if (!origOpLoopRange) {
+    linalgOp.emitError("unable to find loop range for operation");
+    return llvm::None;
+  }
+  SmallVector<unsigned, 4> numFoldedDims(fusedIndexMap.getNumDims(), 1);
   SmallVector<SmallVector<int64_t, 4>, 4> expandedDimsShape(
-      foldedType.getRank());
+      fusedIndexMap.getNumDims());
   auto reassociationMaps = reshapeOp.getReassociationMaps();
   for (auto resultExpr : llvm::enumerate(fusedIndexMap.getResults())) {
     unsigned pos = resultExpr.value().cast<AffineDimExpr>().getPosition();
@@ -467,6 +469,10 @@ fuseWithReshapeByExpansion(LinalgOp linalgOp, TensorReshapeOp reshapeOp,
         expandedShape.slice(foldedDims.getDimPosition(0), numFoldedDims[pos]);
     expandedDimsShape[pos].assign(shape.begin(), shape.end());
   }
+  // The remaining dimensions remain the same.
+  for (unsigned i : llvm::seq<unsigned>(0, fusedIndexMap.getNumDims()))
+    if (expandedDimsShape[i].empty())
+      expandedDimsShape[i] = {(*origOpLoopRange)[i]};
 
   if (isa<IndexedGenericOp>(linalgOp.getOperation())) {
     // For indexed generic op, the region contains arguments that represent the
@@ -476,6 +482,8 @@ fuseWithReshapeByExpansion(LinalgOp linalgOp, TensorReshapeOp reshapeOp,
     // front) are statically know. For dynamic case, we would need shape
     // information on these dimensions to get these.
     for (auto &expandedShape : expandedDimsShape) {
+      if (expandedShape.size() == 1)
+        continue;
       for (int64_t expandedDimShape : llvm::make_range(
                std::next(expandedShape.begin()), expandedShape.end())) {
         if (ShapedType::isDynamic(expandedDimShape)) {

diff --git a/mlir/lib/Dialect/Linalg/Utils/Utils.cpp b/mlir/lib/Dialect/Linalg/Utils/Utils.cpp
index 43f40163da81..8e60312bf4fd 100644
--- a/mlir/lib/Dialect/Linalg/Utils/Utils.cpp
+++ b/mlir/lib/Dialect/Linalg/Utils/Utils.cpp
@@ -104,13 +104,18 @@ SmallVector<int64_t, 8> getStaticShape(LinalgOp linalgOp) {
     auto shape = v.getType().cast<ShapedType>().getShape();
     res.append(shape.begin(), shape.end());
   }
+  if (linalgOp.getNumInitTensors())
+    return res;
+  for (Value v : linalgOp.getOperation()->getResults()) {
+    auto shape = v.getType().cast<ShapedType>().getShape();
+    res.append(shape.begin(), shape.end());
+  }
   return res;
 }
 
 Optional<SmallVector<int64_t, 4>> getStaticLoopRanges(LinalgOp linalgOp) {
   SmallVector<int64_t, 8> viewSizes = getStaticShape(linalgOp);
-  AffineMap invertedMap =
-      inversePermutation(concatAffineMaps(linalgOp.getIndexingMaps()));
+  AffineMap invertedMap = linalgOp.getShapesToLoopsMap();
   if (!invertedMap)
     return {};
   return invertedMap.compose(viewSizes);

diff --git a/mlir/test/Dialect/Linalg/reshape_fusion.mlir b/mlir/test/Dialect/Linalg/reshape_fusion.mlir
index 1f201f78fe74..66e07cc56d65 100644
--- a/mlir/test/Dialect/Linalg/reshape_fusion.mlir
+++ b/mlir/test/Dialect/Linalg/reshape_fusion.mlir
@@ -344,3 +344,97 @@ func @reshape_as_consumer_permutation
 //       CHECK:       %[[T9:.+]] = addi %[[T7]], %[[T8]]
 //       CHECK:       %[[T10:.+]] = index_cast %[[ARG7]]
 //       CHECK:       %[[T11:.+]] = addi %[[T9]], %[[T10]]
+
+// -----
+
+func @reshape_as_producer_projected_permutation
+  (%arg0 : tensor<33x8x?xi32>) -> tensor<264x?x4xi32> {
+  %0 = linalg.tensor_reshape %arg0 [affine_map<(d0, d1, d2) -> (d0, d1)>,
+                                    affine_map<(d0, d1, d2) -> (d2)>]
+    : tensor<33x8x?xi32> into tensor<264x?xi32>
+  %1 = linalg.indexed_generic
+    {indexing_maps = [affine_map<(d0, d1, d2) -> (d0, d1)>,
+                      affine_map<(d0, d1, d2) -> (d0, d1, d2)>],
+     iterator_types = ["parallel", "parallel", "parallel"]} ins(%0 : tensor<264x?xi32>) {
+  ^bb0(%arg1: index, %arg2: index, %arg3: index, %arg4: i32):  // no predecessors
+    %2 = index_cast %arg1 : index to i32
+    %3 = addi %arg4, %2 : i32
+    %4 = index_cast %arg2 : index to i32
+    %5 = addi %3, %4 : i32
+    %6 = index_cast %arg3 : index to i32
+    %7 = addi %5, %6 : i32
+    linalg.yield %7 : i32
+  } -> tensor<264x?x4xi32>
+  return %1 : tensor<264x?x4xi32>
+}
+
+//   CHECK-DAG: #[[MAP0:.+]] = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2)>
+//   CHECK-DAG: #[[MAP1:.+]] = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>
+//   CHECK-DAG: #[[MAP2:.+]] = affine_map<(d0, d1) -> (d0 * 8 + d1)>
+//   CHECK-DAG: #[[MAP3:.+]] = affine_map<(d0, d1, d2, d3) -> (d0, d1)>
+//   CHECK-DAG: #[[MAP4:.+]] = affine_map<(d0, d1, d2, d3) -> (d2)>
+//   CHECK-DAG: #[[MAP5:.+]] = affine_map<(d0, d1, d2, d3) -> (d3)>
+//       CHECK: @reshape_as_producer_projected_permutation
+//  CHECK-SAME:   %[[ARG0:.+]]: tensor<33x8x?xi32>
+//       CHECK:   %[[RES:.+]] = linalg.indexed_generic
+//  CHECK-SAME:     indexing_maps = [#[[MAP0]], #[[MAP1]]]
+//  CHECK-SAME:     ins(%[[ARG0]] : tensor<33x8x?xi32>)
+//       CHECK:   ^{{.+}}(
+//  CHECK-SAME:     %[[ARG1:[a-zA-Z0-9]+]]: index,
+//  CHECK-SAME:     %[[ARG2:[a-zA-Z0-9]+]]: index,
+//  CHECK-SAME:     %[[ARG3:[a-zA-Z0-9]+]]: index,
+//  CHECK-SAME:     %[[ARG4:[a-zA-Z0-9]+]]: index,
+//  CHECK-SAME:     %[[ARG5:[a-zA-Z0-9]+]]: i32)
+//       CHECK:       %[[T0:.+]] = affine.apply #[[MAP2]](%[[ARG1]], %[[ARG2]])
+//       CHECK:       %[[T1:.+]] = index_cast %[[T0]] : index to i32
+//       CHECK:       %[[T2:.+]] = addi %[[ARG5]], %[[T1]] : i32
+//       CHECK:       %[[T3:.+]] = index_cast %[[ARG3]] : index to i32
+//       CHECK:       %[[T4:.+]] = addi %[[T2]], %[[T3]] : i32
+//       CHECK:       %[[T5:.+]] = index_cast %[[ARG4]] : index to i32
+//       CHECK:       %[[T6:.+]] = addi %[[T4]], %[[T5]] : i32
+//       CHECK:       linalg.yield %[[T6]] : i32
+//       CHECK:    %[[RES2:.+]] = linalg.tensor_reshape %[[RES]]
+//  CHECK-SAME:      [#[[MAP3]], #[[MAP4]], #[[MAP5]]]
+//  CHECK-SAME:    : tensor<33x8x?x4xi32> into tensor<264x?x4xi32>
+//       CHECK:  return %[[RES2]] : tensor<264x?x4xi32>
+
+// -----
+
+#map0 = affine_map<(d0, d1) -> (d0, d1)>
+#map1 = affine_map<(d0, d1) -> (d1, d0)>
+func @generic_op_reshape_consumer_fusion_projected(%arg0 : tensor<?x?xf32>,
+                                                   %arg1 : tensor<?x?xf32>) ->
+                                                   tensor<?x?x4x5xf32>
+{
+  %0 = linalg.generic {
+     indexing_maps = [#map0, #map0, #map1],
+     iterator_types = ["parallel", "parallel"]}
+      ins(%arg0, %arg1 : tensor<?x?xf32>, tensor<?x?xf32>) {
+    ^bb0(%arg3: f32, %arg4: f32):       // no predecessors
+      %1 = mulf %arg3, %arg4 : f32
+      linalg.yield %1 : f32
+  } -> tensor<?x?xf32>
+  %1 = linalg.tensor_reshape %0 [affine_map<(i, j, k, l) -> (i)>,
+                                 affine_map<(i, j, k, l) -> (j, k, l)>] :
+    tensor<?x?xf32> into tensor<?x?x4x5xf32>
+  return %1 : tensor<?x?x4x5xf32>
+}
+
+//  CHECK-DAG: #[[MAP0:.+]] = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2)>
+//  CHECK-DAG: #[[MAP1:.+]] = affine_map<(d0, d1, d2, d3) -> (d3)>
+//  CHECK-DAG: #[[MAP2:.+]] = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>
+//  CHECK-DAG: #[[MAP3:.+]] = affine_map<(d0, d1, d2, d3) -> (d3, d0, d1, d2)>
+//      CHECK: func @generic_op_reshape_consumer_fusion_projected
+// CHECK-SAME:   %[[ARG0:[a-zA-Z0-9_]+]]: tensor<?x?xf32>
+// CHECK-SAME:   %[[ARG1:[a-zA-Z0-9_]+]]: tensor<?x?xf32>
+//      CHECK:   %[[T0:.+]] = linalg.tensor_reshape %[[ARG0]]
+// CHECK-SAME:     [#[[MAP0]], #[[MAP1]]]
+// CHECK-SAME:     tensor<?x?xf32> into tensor<?x4x5x?xf32>
+//      CHECK:   %[[T1:.+]] = linalg.tensor_reshape %[[ARG1]]
+// CHECK-SAME:     [#[[MAP0]], #[[MAP1]]]
+// CHECK-SAME:     tensor<?x?xf32> into tensor<?x4x5x?xf32>
+//      CHECK:   %[[T2:.+]] = linalg.generic
+// CHECK-SAME:     indexing_maps = [#[[MAP2]], #[[MAP2]], #[[MAP3]]]
+// CHECK-SAME:     ["parallel", "parallel", "parallel", "parallel"]
+// CHECK-SAME:     ins(%[[T0]], %[[T1]] : tensor<?x4x5x?xf32>, tensor<?x4x5x?xf32>)
+//      CHECK:   return %[[T2]] : tensor<?x?x4x5xf32>


        