[llvm-branch-commits] [mlir] [MLIR] Legalize certain `vector.transfer_read` ops of scalable vectors (PR #143146)
@@ -298,16 +298,139 @@ struct LegalizeSVEMaskLoadConversion : public OpRewritePattern<memref::LoadOp> {
}
};
+/// Transforms a `transfer_read` operation so it reads a vector of a type that
+/// can be mapped to an LLVM type. This is done by collapsing trailing
+/// dimensions so we obtain a vector type with a single scalable dimension in
+/// the rightmost position.
+///
+/// Example:
+/// ```
+/// %v = vector.transfer_read %M[%i, %j, %c0, %c0], %c0_i8
+///   {in_bounds = [false, true, true, true]}
+///   : memref<?x?x2x8xi8>, vector<2x[4]x2x8xi8>
+/// ```
+/// is rewritten to
+/// ```
+/// %collapse_shape = memref.collapse_shape %M [[0], [1, 2, 3]]
+///   : memref<?x?x2x8xi8> into memref<?x?xi8>
+/// %0 = vector.transfer_read %collapse_shape[%i, %j], %c0_i8
+///   {in_bounds = [false, true]}
+///   : memref<?x?xi8>, vector<2x[64]xi8>
+/// %1 = vector.shape_cast %0 : vector<2x[64]xi8> to vector<2x[4]x2x8xi8>
+/// ```
+struct LegalizeTransferRead : public OpRewritePattern<vector::TransferReadOp> {
+  using OpRewritePattern::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(vector::TransferReadOp readOp,
+                                PatternRewriter &rewriter) const override {
+
+    // Do not try to transform masked reads. For example, if we have a
+    // transfer to a `vector<[4]x4xi8>` we could have a mask like
+    //    1 1 1 0
+    //    1 1 1 0
+    //    1 1 1 0
+    //    0 0 0 0
+    // Flattening this mask would look like
+    //    1 1 1 0 1 1 1 0 1 1 1 0 0 0 0 0
+    // and we have not yet figured out an efficient way to build such a mask,
+    // neither from the mask operand, nor from the original
+    // `vector.create_mask` operation (if visible at all).
+    if (readOp.isMasked() || readOp.getMask())
+      return rewriter.notifyMatchFailure(readOp,
+                                         "masked transfers not supported");
+
+    if (!readOp.getPermutationMap().isMinorIdentity())
+      return rewriter.notifyMatchFailure(readOp, "non-identity permutation");
+
+    // We handle transfers of vectors with rank >= 2 and a single scalable
+    // dimension.
+    VectorType origVT = readOp.getVectorType();
+    ArrayRef<bool> origScalableDims = origVT.getScalableDims();
+    const int64_t origVRank = origVT.getRank();
+    if (origVRank < 2 || llvm::count(origScalableDims, true) != 1)
+      return rewriter.notifyMatchFailure(readOp, "wrong dimensions");
+
+    // Number of trailing dimensions to collapse, including the scalable
+    // dimension. Nothing to do if the single scalable dimension is already
+    // the last one.
+    const int64_t numCollapseDims = std::distance(
+        llvm::find(origScalableDims, true), origScalableDims.end());
+    if (numCollapseDims < 2)
+      return rewriter.notifyMatchFailure(readOp,
+                                         "scalable dimension is trailing");
+
+    // We want a simple memref (not a tensor) with contiguous elements for at
+    // least all the trailing dimensions up to and including the scalable one.
+    auto memTy = dyn_cast<MemRefType>(readOp.getBase().getType());
+    if (!(memTy && memTy.areTrailingDimsContiguous(numCollapseDims)))
+      return rewriter.notifyMatchFailure(
+          readOp, "non-contiguous memref dimensions to collapse");
+
+    // The collapsed dimensions (excluding the scalable one) of the vector and
+    // the memref must match and the corresponding indices must be in-bounds
+    // (it follows these indices would be zero). This guarantees that the
+    // operation transfers a contiguous block.
+    if (!llvm::equal(memTy.getShape().take_back(numCollapseDims - 1),
+                     origVT.getShape().take_back(numCollapseDims - 1)))
+      return rewriter.notifyMatchFailure(
+          readOp, "memref and vector dimensions do not match");
+
+    SmallVector<bool> origInBounds = readOp.getInBoundsValues();
+    if (!llvm::all_of(
+            ArrayRef(origInBounds).take_back(numCollapseDims - 1),
+            [](bool v) { return v; }))
+      return rewriter.notifyMatchFailure(readOp,
+                                         "out-of-bounds index to collapse");
momchil-velikov wrote:
Fixed.
https://github.com/llvm/llvm-project/pull/143146
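To make the mask discussion quoted above concrete, here is a minimal sketch of the masked form the pattern deliberately skips (a hypothetical example, not taken from the patch; it assumes the mask comes from a `vector.create_mask`):

```mlir
func.func @masked_read_is_skipped(%M : memref<?x?xi8>, %i : index, %j : index,
                                  %n : index) -> vector<[4]x4xi8> {
  %c0_i8 = arith.constant 0 : i8
  // A 2-D mask shaped like the 1-1-1-0 pattern in the comment above.
  %mask = vector.create_mask %n, %n : vector<[4]x4xi1>
  // LegalizeTransferRead bails out here: flattening %mask into a 1-D mask
  // for a collapsed read has no known efficient construction.
  %v = vector.transfer_read %M[%i, %j], %c0_i8, %mask
      {in_bounds = [true, true]} : memref<?x?xi8>, vector<[4]x4xi8>
  return %v : vector<[4]x4xi8>
}
```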
[llvm-branch-commits] [mlir] [MLIR] Legalize certain `vector.transfer_read` ops of scalable vectors (PR #143146)
https://github.com/banach-space edited https://github.com/llvm/llvm-project/pull/143146
[llvm-branch-commits] [mlir] [MLIR] Legalize certain `vector.transfer_read` ops of scalable vectors (PR #143146)
@@ -298,16 +298,139 @@ struct LegalizeSVEMaskLoadConversion : public OpRewritePattern<memref::LoadOp> {
}
};
+/// Transforms a `transfer_read` operation so it reads a vector of a type that
+/// can be mapped to an LLVM type. This is done by collapsing trailing
+/// dimensions so we obtain a vector type with a single scalable dimension in
+/// the rightmost position.
+///
+/// Example:
+/// ```
+/// %v = vector.transfer_read %M[%i, %j, %c0, %c0], %c0_i8
+///   {in_bounds = [false, true, true, true]}
+///   : memref<?x?x2x8xi8>, vector<2x[4]x2x8xi8>
+/// ```
+/// is rewritten to
+/// ```
+/// %collapse_shape = memref.collapse_shape %M [[0], [1, 2, 3]]
+///   : memref<?x?x2x8xi8> into memref<?x?xi8>
+/// %0 = vector.transfer_read %collapse_shape[%i, %j], %c0_i8
+///   {in_bounds = [false, true]}
+///   : memref<?x?xi8>, vector<2x[64]xi8>
+/// %1 = vector.shape_cast %0 : vector<2x[64]xi8> to vector<2x[4]x2x8xi8>
+/// ```
+struct LegalizeTransferRead : public OpRewritePattern<vector::TransferReadOp> {
+  using OpRewritePattern::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(vector::TransferReadOp readOp,
+                                PatternRewriter &rewriter) const override {
+
+    // Do not try to transform masked reads. For example, if we have a
+    // transfer to a `vector<[4]x4xi8>` we could have a mask like
+    //    1 1 1 0
+    //    1 1 1 0
+    //    1 1 1 0
+    //    0 0 0 0
+    // Flattening this mask would look like
+    //    1 1 1 0 1 1 1 0 1 1 1 0 0 0 0 0
+    // and we have not yet figured out an efficient way to build such a mask,
+    // neither from the mask operand, nor from the original
+    // `vector.create_mask` operation (if visible at all).
+    if (readOp.isMasked() || readOp.getMask())
+      return rewriter.notifyMatchFailure(readOp,
+                                         "masked transfers not supported");
+
+    if (!readOp.getPermutationMap().isMinorIdentity())
+      return rewriter.notifyMatchFailure(readOp, "non-identity permutation");
+
+    // We handle transfers of vectors with rank >= 2 and a single scalable
+    // dimension.
+    VectorType origVT = readOp.getVectorType();
+    ArrayRef<bool> origScalableDims = origVT.getScalableDims();
+    const int64_t origVRank = origVT.getRank();
+    if (origVRank < 2 || llvm::count(origScalableDims, true) != 1)
+      return rewriter.notifyMatchFailure(readOp, "wrong dimensions");
+
+    // Number of trailing dimensions to collapse, including the scalable
+    // dimension. Nothing to do if the single scalable dimension is already
+    // the last one.
+    const int64_t numCollapseDims = std::distance(
+        llvm::find(origScalableDims, true), origScalableDims.end());
+    if (numCollapseDims < 2)
+      return rewriter.notifyMatchFailure(readOp,
+                                         "scalable dimension is trailing");
+
+    // We want a simple memref (not a tensor) with contiguous elements for at
+    // least all the trailing dimensions up to and including the scalable one.
+    auto memTy = dyn_cast<MemRefType>(readOp.getBase().getType());
+    if (!(memTy && memTy.areTrailingDimsContiguous(numCollapseDims)))
+      return rewriter.notifyMatchFailure(
+          readOp, "non-contiguous memref dimensions to collapse");
+
+    // The collapsed dimensions (excluding the scalable one) of the vector and
+    // the memref must match and the corresponding indices must be in-bounds
+    // (it follows these indices would be zero). This guarantees that the
+    // operation transfers a contiguous block.
banach-space wrote:
> // The collapsed dimensions (excluding the scalable one) of the vector and
> // the memref must match

What about dynamic dim sizes in the memref? If that's not supported, is there a test?
https://github.com/llvm/llvm-project/pull/143146
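For illustration, a hypothetical input matching the dynamic-size question (the function name and shapes are assumptions, not from the patch): the memref dimension to collapse is dynamic, so the static shape comparison in the pattern cannot match it against the vector's trailing size and the rewrite should bail out:

```mlir
func.func @negative_dynamic_dim_to_collapse(%i : index, %j : index,
    %M : memref<?x?x?x?xi8>) -> vector<[4]x8xi8> {
  %c0 = arith.constant 0 : index
  %c0_i8 = arith.constant 0 : i8
  // The trailing memref dim is `?`, which cannot be statically matched
  // against the vector's trailing static size 8.
  %A = vector.transfer_read %M[%i, %j, %c0, %c0], %c0_i8
      {in_bounds = [true, true]} : memref<?x?x?x?xi8>, vector<[4]x8xi8>
  return %A : vector<[4]x8xi8>
}
```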
[llvm-branch-commits] [mlir] [MLIR] Legalize certain `vector.transfer_read` ops of scalable vectors (PR #143146)
@@ -298,16 +298,139 @@ struct LegalizeSVEMaskLoadConversion : public OpRewritePattern<memref::LoadOp> {
}
};
+/// Transforms a `transfer_read` operation so it reads a vector of a type that
+/// can be mapped to an LLVM type. This is done by collapsing trailing
+/// dimensions so we obtain a vector type with a single scalable dimension in
+/// the rightmost position.
+///
+/// Example:
+/// ```
+/// %v = vector.transfer_read %M[%i, %j, %c0, %c0], %c0_i8
+///   {in_bounds = [false, true, true, true]}
+///   : memref<?x?x2x8xi8>, vector<2x[4]x2x8xi8>
+/// ```
+/// is rewritten to
+/// ```
+/// %collapse_shape = memref.collapse_shape %M [[0], [1, 2, 3]]
+///   : memref<?x?x2x8xi8> into memref<?x?xi8>
+/// %0 = vector.transfer_read %collapse_shape[%i, %j], %c0_i8
+///   {in_bounds = [false, true]}
+///   : memref<?x?xi8>, vector<2x[64]xi8>
+/// %1 = vector.shape_cast %0 : vector<2x[64]xi8> to vector<2x[4]x2x8xi8>
+/// ```
+struct LegalizeTransferRead : public OpRewritePattern<vector::TransferReadOp> {
+  using OpRewritePattern::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(vector::TransferReadOp readOp,
+                                PatternRewriter &rewriter) const override {
+
+    // Do not try to transform masked reads. For example, if we have a
+    // transfer to a `vector<[4]x4xi8>` we could have a mask like
+    //    1 1 1 0
+    //    1 1 1 0
+    //    1 1 1 0
+    //    0 0 0 0
+    // Flattening this mask would look like
+    //    1 1 1 0 1 1 1 0 1 1 1 0 0 0 0 0
+    // and we have not yet figured out an efficient way to build such a mask,
+    // neither from the mask operand, nor from the original
+    // `vector.create_mask` operation (if visible at all).
+    if (readOp.isMasked() || readOp.getMask())
+      return rewriter.notifyMatchFailure(readOp,
+                                         "masked transfers not supported");
+
+    if (!readOp.getPermutationMap().isMinorIdentity())
+      return rewriter.notifyMatchFailure(readOp, "non-identity permutation");
momchil-velikov wrote:
Done.
https://github.com/llvm/llvm-project/pull/143146
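As a sketch of what the minor-identity check above rejects (a hypothetical example, not from the patch — the `#transposed` map and function name are assumptions):

```mlir
#transposed = affine_map<(d0, d1) -> (d1, d0)>

func.func @negative_transposed_read(%M : memref<?x?xi8>, %i : index,
    %j : index) -> vector<[4]x8xi8> {
  %c0_i8 = arith.constant 0 : i8
  // The permutation map swaps the dimensions, so it is not a minor
  // identity and the pattern reports "non-identity permutation".
  %v = vector.transfer_read %M[%i, %j], %c0_i8
      {permutation_map = #transposed, in_bounds = [true, true]}
      : memref<?x?xi8>, vector<[4]x8xi8>
  return %v : vector<[4]x8xi8>
}
```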
[llvm-branch-commits] [mlir] [MLIR] Legalize certain `vector.transfer_read` ops of scalable vectors (PR #143146)
@@ -0,0 +1,262 @@
+// RUN: mlir-opt --arm-sve-legalize-vector-storage --split-input-file %s | FileCheck %s
+
+// -----
+
+// CHECK-LABEL: @test_base_case
+// CHECK-SAME:    %[[I:arg0]]: index, %[[J:arg1]]: index, %[[M:arg2]]: memref<?x?x?x8xi8>
+// CHECK:         %[[COLLAPSE:.+]] = memref.collapse_shape %[[M]]
+// CHECK-SAME{LITERAL}:  [[0], [1], [2, 3]]
+// CHECK-SAME:      : memref<?x?x?x8xi8> into memref<?x?x?xi8>
+// CHECK-NEXT:    %[[T0:.+]] = vector.transfer_read %[[COLLAPSE]][%[[I]], %[[J]], %c0], %c0_i8 {in_bounds = [true]}
+// CHECK-SAME:      : memref<?x?x?xi8>, vector<[32]xi8>
+// CHECK-NEXT:    %[[T1:.+]] = vector.shape_cast %[[T0]] : vector<[32]xi8> to vector<[4]x8xi8>
+// CHECK-NEXT:    return %[[T1]] : vector<[4]x8xi8>
+
+func.func @test_base_case(%i : index, %j : index, %M : memref<?x?x?x8xi8>) -> vector<[4]x8xi8> {
+  %c0 = arith.constant 0 : index
+  %c0_i8 = arith.constant 0 : i8
+
+  %A = vector.transfer_read %M[%i, %j, %c0, %c0], %c0_i8 {in_bounds = [true, true]} : memref<?x?x?x8xi8>, vector<[4]x8xi8>
+
+  return %A : vector<[4]x8xi8>
+}
+
+// -----
+
+// CHECK-LABEL: @test_using_strided_layout
+// CHECK-SAME:    %[[I:arg0]]: index, %[[J:arg1]]: index, %[[M:arg2]]
+// CHECK:         %[[COLLAPSE:.+]] = memref.collapse_shape %[[M]]
+// CHECK-SAME{LITERAL}:  [[0], [1], [2, 3]]
+// CHECK-SAME:      : memref<?x?x?x8xi8, strided<[?, ?, 8, 1]>> into
+// CHECK-SAME:      memref<?x?x?xi8, strided<[?, ?, 1]>>
+// CHECK-NEXT:    %[[T0:.+]] = vector.transfer_read %[[COLLAPSE]][%[[I]], %[[J]], %c0], %c0_i8 {in_bounds = [true]}
+// CHECK-SAME:      : memref<?x?x?xi8, strided<[?, ?, 1]>>, vector<[32]xi8>
+// CHECK-NEXT:    %[[T1:.+]] = vector.shape_cast %[[T0]] : vector<[32]xi8> to vector<[4]x8xi8>
+// CHECK-NEXT:    return %[[T1]] : vector<[4]x8xi8>
+
+#s0 = strided<[?, ?, 8, 1]>
+
+func.func @test_using_strided_layout(%i : index, %j : index, %M : memref<?x?x?x8xi8, #s0>) -> vector<[4]x8xi8> {
+  %c0 = arith.constant 0 : index
+  %c0_i8 = arith.constant 0 : i8
+
+  %A = vector.transfer_read %M[%i, %j, %c0, %c0], %c0_i8 {in_bounds = [true, true]} : memref<?x?x?x8xi8, #s0>, vector<[4]x8xi8>
+
+  return %A : vector<[4]x8xi8>
+}
+
+// -----
+
+// CHECK-LABEL: @test_3d_vector
+// CHECK-SAME:    %[[I:arg0]]: index, %[[J:arg1]]: index, %[[M:arg2]]
+// CHECK:         %[[COLLAPSED:.+]] = memref.collapse_shape %[[M]]
+// CHECK-SAME{LITERAL}:  [[0], [1, 2, 3]]
+// CHECK-SAME:      : memref<?x?x2x8xi8, strided<[?, 16, 8, 1]>> into
+// CHECK-SAME:      memref<?x?xi8, strided<[?, 1]>>
+// CHECK-NEXT:    %[[T0:.+]] = vector.transfer_read %[[COLLAPSED]][%[[I]], %[[J]]], %c0_i8 {in_bounds = [true]}
+// CHECK-SAME:      : memref<?x?xi8, strided<[?, 1]>>, vector<[64]xi8>
+// CHECK-NEXT:    %[[T1:.+]] = vector.shape_cast %[[T0]] : vector<[64]xi8> to vector<[4]x2x8xi8>
+// CHECK-NEXT:    return %[[T1]] : vector<[4]x2x8xi8>
+
+#s1 = strided<[?, 16, 8, 1]>
+
+func.func @test_3d_vector(%i : index, %j : index, %M : memref<?x?x2x8xi8, #s1>) -> vector<[4]x2x8xi8> {
+  %c0 = arith.constant 0 : index
+  %c0_i8 = arith.constant 0 : i8
+
+  %A = vector.transfer_read %M[%i, %j, %c0, %c0], %c0_i8 {in_bounds = [true, true, true]} : memref<?x?x2x8xi8, #s1>, vector<[4]x2x8xi8>
+
+  return %A : vector<[4]x2x8xi8>
+}
+
+// -----
+
+// CHECK-LABEL: @test_4d_vector
+// CHECK-SAME:    %[[I:arg0]]: index, %[[J:arg1]]: index, %[[M:arg2]]
+// CHECK:         %[[COLLAPSED:.+]] = memref.collapse_shape %[[M]]
+// CHECK-SAME{LITERAL}:  [[0], [1, 2, 3]]
+// CHECK-SAME:      : memref<?x?x2x8xi8, strided<[?, 16, 8, 1]>> into
+// CHECK-SAME:      memref<?x?xi8, strided<[?, 1]>>
+// CHECK-NEXT:    %[[T0:.+]] = vector.transfer_read %[[COLLAPSED]][%[[I]], %[[J]]], %c0_i8 {in_bounds = [false, true]}
+// CHECK-SAME:      : memref<?x?xi8, strided<[?, 1]>>, vector<2x[64]xi8>
+// CHECK-NEXT:    %[[T1:.+]] = vector.shape_cast %[[T0]] : vector<2x[64]xi8> to vector<2x[4]x2x8xi8>
+// CHECK-NEXT:    return %[[T1]] : vector<2x[4]x2x8xi8>
+
+#s2 = strided<[?, 16, 8, 1]>
+
+func.func @test_4d_vector(%i : index, %j : index, %M : memref<?x?x2x8xi8, #s2>) -> vector<2x[4]x2x8xi8> {
+  %c0 = arith.constant 0 : index
+  %c0_i8 = arith.constant 0 : i8
+
+  %A = vector.transfer_read %M[%i, %j, %c0, %c0], %c0_i8 {in_bounds = [false, true, true, true]} : memref<?x?x2x8xi8, #s2>, vector<2x[4]x2x8xi8>
+
+  return %A : vector<2x[4]x2x8xi8>
+}
+
+// -----
+
+// CHECK-LABEL: @negative_test_vector_legal_non_scalable
+// CHECK-NOT: memref.collapse
+
+func.func @negative_test_vector_legal_non_scalable(%i : index, %j : index, %M : memref<?x?x?x8xi8>) -> vector<8x8xi8> {
+  %c0 = arith.constant 0 : index
+  %c0_i8 = arith.constant 0 : i8
+
+  %A = vector.transfer_read %M[%i, %j, %c0, %c0], %c0_i8 {in_bounds = [true, true]} : memref<?x?x?x8xi8>, vector<8x8xi8>
+
+  return %A : vector<8x8xi8>
+}
+
+// -----
+
+// CHECK-LABEL: @negative_test_vector_legal_scalable_0
+// CHECK-NOT: memref.collapse
+
+func.func @negative_test_vector_legal_scalable_0(%i : index, %j : index, %M : memref<?x?x?x8xi8>) -> vector<[8]xi8> {
+  %c0 = arith.constant 0 : index
+  %c0_i8 = arith.constant 0 : i8
+
+  %A = vector.transfer_read %M[%i, %j, %c0, %c0], %c0_i8 {in_bounds = [true]} : memref<?x?x?x8xi8>, vector<[8]xi8>
+
+  return %A : vector<[8]xi8>
+}
[llvm-branch-commits] [mlir] [MLIR] Legalize certain `vector.transfer_read` ops of scalable vectors (PR #143146)
https://github.com/momchil-velikov updated https://github.com/llvm/llvm-project/pull/143146

From 198ed819841270aeec7159fe2a9a4c092b8d8af7 Mon Sep 17 00:00:00 2001
From: Momchil Velikov
Date: Wed, 14 May 2025 09:03:49 +0000
Subject: [PATCH 1/4] [MLIR] Legalize certain `vector.transfer_read` ops of
 scalable vectors

This patch adds a transform of `transfer_read` operations to change the
vector type to one that can be mapped to an LLVM type. This is done by
collapsing trailing dimensions so we obtain a vector type with a single
scalable dimension in the rightmost position.
---
 .../Transforms/LegalizeVectorStorage.cpp      | 110 +++++++++-
 .../ArmSVE/legalize-transfer-read.mlir        | 226 ++++++++++++++
 .../transfer-read-scalable-not-rightmost.mlir |  72 ++++++
 3 files changed, 407 insertions(+), 1 deletion(-)
 create mode 100644 mlir/test/Dialect/ArmSVE/legalize-transfer-read.mlir
 create mode 100644 mlir/test/Integration/Dialect/Vector/CPU/ArmSVE/transfer-read-scalable-not-rightmost.mlir
diff --git a/mlir/lib/Dialect/ArmSVE/Transforms/LegalizeVectorStorage.cpp b/mlir/lib/Dialect/ArmSVE/Transforms/LegalizeVectorStorage.cpp
index d2ac850a5f70b..f16d33c004fec 100644
--- a/mlir/lib/Dialect/ArmSVE/Transforms/LegalizeVectorStorage.cpp
+++ b/mlir/lib/Dialect/ArmSVE/Transforms/LegalizeVectorStorage.cpp
@@ -298,6 +298,113 @@ struct LegalizeSVEMaskLoadConversion : public OpRewritePattern<memref::LoadOp> {
}
};
+/// Transforms a `transfer_read` operation so it reads a vector of a type that
+/// can be mapped to an LLVM type. This is done by collapsing trailing
+/// dimensions so we obtain a vector type with a single scalable dimension in
+/// the rightmost position.
+///
+/// Example:
+/// ```
+/// %v = vector.transfer_read %M[%i, %j, %c0, %c0], %c0_i8
+///   {in_bounds = [false, true, true, true]}
+///   : memref<?x?x2x8xi8>, vector<2x[4]x2x8xi8>
+/// ```
+/// is rewritten to
+/// ```
+/// %collapse_shape = memref.collapse_shape %M [[0], [1, 2, 3]]
+///   : memref<?x?x2x8xi8> into memref<?x?xi8>
+/// %0 = vector.transfer_read %collapse_shape[%i, %j], %c0_i8
+///   {in_bounds = [false, true]}
+///   : memref<?x?xi8>, vector<2x[64]xi8>
+/// %1 = vector.shape_cast %0 : vector<2x[64]xi8> to vector<2x[4]x2x8xi8>
+/// ```
+struct LegalizeTransferRead : public OpRewritePattern<vector::TransferReadOp> {
+  using OpRewritePattern::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(vector::TransferReadOp readOp,
+                                PatternRewriter &rewriter) const override {
+
+    if (!readOp.getPermutationMap().isMinorIdentity())
+      return rewriter.notifyMatchFailure(readOp, "non-identity permutation");
+
+    // We handle transfers of vectors with rank >= 2 and a single scalable
+    // dimension.
+    VectorType origVT = readOp.getVectorType();
+    ArrayRef<bool> origScalableDims = origVT.getScalableDims();
+    const int64_t origVRank = origVT.getRank();
+    if (origVRank < 2 || llvm::count(origScalableDims, true) != 1)
+      return rewriter.notifyMatchFailure(readOp, "wrong dimensions");
+
+    // Number of trailing dimensions to collapse, including the scalable
+    // dimension. Nothing to do if the single scalable dimension is already
+    // the last one.
+    const int64_t numCollapseDims = std::distance(
+        llvm::find(origScalableDims, true), origScalableDims.end());
+    if (numCollapseDims < 2)
+      return rewriter.notifyMatchFailure(readOp,
+                                         "scalable dimension is trailing");
+
+    // We want a simple memref (not a tensor) with contiguous elements for at
+    // least all the trailing dimensions up to and including the scalable one.
+    auto memTy = dyn_cast<MemRefType>(readOp.getBase().getType());
+    if (!(memTy && memTy.areTrailingDimsContiguous(numCollapseDims)))
+      return rewriter.notifyMatchFailure(
+          readOp, "non-contiguous memref dimensions to collapse");
+
+    // The collapsed dimensions (excluding the scalable one) of the vector and
+    // the memref must match and the corresponding indices must be in-bounds
+    // (it follows these indices would be zero). This guarantees that the
+    // operation transfers a contiguous block.
+    if (!llvm::equal(memTy.getShape().take_back(numCollapseDims - 1),
+                     origVT.getShape().take_back(numCollapseDims - 1)))
+      return rewriter.notifyMatchFailure(
+          readOp, "memref and vector dimensions do not match");
+
+    SmallVector<bool> origInBounds = readOp.getInBoundsValues();
+    if (!llvm::all_of(
+            ArrayRef(origInBounds).take_back(numCollapseDims - 1),
+            [](bool v) { return v; }))
+      return rewriter.notifyMatchFailure(readOp,
+                                         "out-of-bounds index to collapse");
+
+    // Collapse the trailing dimensions of the memref.
+    SmallVector<ReassociationIndices> reassoc;
+    for (int64_t i = 0; i < memTy.getRank() - numCollapseDims + 1; ++i)
+      reassoc.push_back({i});
+    for (int64_t i = memTy.getRank() - numCollapseDims + 1;
+         i < memTy.getRank(); ++i)
+      reassoc.back().push_back(i);
[llvm-branch-commits] [mlir] [MLIR] Legalize certain `vector.transfer_read` ops of scalable vectors (PR #143146)
@@ -298,16 +298,139 @@ struct LegalizeSVEMaskLoadConversion : public OpRewritePattern<memref::LoadOp> {
}
};
+/// Transforms a `transfer_read` operation so it reads a vector of a type that
+/// can be mapped to an LLVM type. This is done by collapsing trailing
+/// dimensions so we obtain a vector type with a single scalable dimension in
+/// the rightmost position.
+///
+/// Example:
+/// ```
+/// %v = vector.transfer_read %M[%i, %j, %c0, %c0], %c0_i8
+///   {in_bounds = [false, true, true, true]}
+///   : memref<?x?x2x8xi8>, vector<2x[4]x2x8xi8>
+/// ```
+/// is rewritten to
+/// ```
+/// %collapse_shape = memref.collapse_shape %M [[0], [1, 2, 3]]
+///   : memref<?x?x2x8xi8> into memref<?x?xi8>
+/// %0 = vector.transfer_read %collapse_shape[%i, %j], %c0_i8
+///   {in_bounds = [false, true]}
+///   : memref<?x?xi8>, vector<2x[64]xi8>
+/// %1 = vector.shape_cast %0 : vector<2x[64]xi8> to vector<2x[4]x2x8xi8>
+/// ```
+struct LegalizeTransferRead : public OpRewritePattern<vector::TransferReadOp> {
+  using OpRewritePattern::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(vector::TransferReadOp readOp,
+                                PatternRewriter &rewriter) const override {
+
+    // Do not try to transform masked reads. For example, if we have a
+    // transfer to a `vector<[4]x4xi8>` we could have a mask like
+    //    1 1 1 0
+    //    1 1 1 0
+    //    1 1 1 0
+    //    0 0 0 0
+    // Flattening this mask would look like
+    //    1 1 1 0 1 1 1 0 1 1 1 0 0 0 0 0
+    // and we have not yet figured out an efficient way to build such a mask,
+    // neither from the mask operand, nor from the original
+    // `vector.create_mask` operation (if visible at all).
+    if (readOp.isMasked() || readOp.getMask())
+      return rewriter.notifyMatchFailure(readOp,
+                                         "masked transfers not supported");
+
+    if (!readOp.getPermutationMap().isMinorIdentity())
+      return rewriter.notifyMatchFailure(readOp, "non-identity permutation");
+
+    // We handle transfers of vectors with rank >= 2 and a single scalable
+    // dimension.
momchil-velikov wrote:
Comment added.
https://github.com/llvm/llvm-project/pull/143146
[llvm-branch-commits] [mlir] [MLIR] Legalize certain `vector.transfer_read` ops of scalable vectors (PR #143146)
@@ -298,16 +298,139 @@ struct LegalizeSVEMaskLoadConversion : public OpRewritePattern<memref::LoadOp> {
}
};
+/// Transforms a `transfer_read` operation so it reads a vector of a type that
+/// can be mapped to an LLVM type. This is done by collapsing trailing
+/// dimensions so we obtain a vector type with a single scalable dimension in
+/// the rightmost position.
+///
+/// Example:
+/// ```
+/// %v = vector.transfer_read %M[%i, %j, %c0, %c0], %c0_i8
+///   {in_bounds = [false, true, true, true]}
+///   : memref<?x?x2x8xi8>, vector<2x[4]x2x8xi8>
+/// ```
+/// is rewritten to
+/// ```
+/// %collapse_shape = memref.collapse_shape %M [[0], [1, 2, 3]]
+///   : memref<?x?x2x8xi8> into memref<?x?xi8>
+/// %0 = vector.transfer_read %collapse_shape[%i, %j], %c0_i8
+///   {in_bounds = [false, true]}
+///   : memref<?x?xi8>, vector<2x[64]xi8>
+/// %1 = vector.shape_cast %0 : vector<2x[64]xi8> to vector<2x[4]x2x8xi8>
+/// ```
+struct LegalizeTransferRead : public OpRewritePattern<vector::TransferReadOp> {
+  using OpRewritePattern::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(vector::TransferReadOp readOp,
+                                PatternRewriter &rewriter) const override {
+
+    // Do not try to transform masked reads. For example, if we have a
+    // transfer to a `vector<[4]x4xi8>` we could have a mask like
+    //    1 1 1 0
+    //    1 1 1 0
+    //    1 1 1 0
+    //    0 0 0 0
+    // Flattening this mask would look like
+    //    1 1 1 0 1 1 1 0 1 1 1 0 0 0 0 0
+    // and we have not yet figured out an efficient way to build such a mask,
+    // neither from the mask operand, nor from the original
+    // `vector.create_mask` operation (if visible at all).
+    if (readOp.isMasked() || readOp.getMask())
+      return rewriter.notifyMatchFailure(readOp,
+                                         "masked transfers not supported");
+
+    if (!readOp.getPermutationMap().isMinorIdentity())
+      return rewriter.notifyMatchFailure(readOp, "non-identity permutation");
+
+    // We handle transfers of vectors with rank >= 2 and a single scalable
+    // dimension.
+    VectorType origVT = readOp.getVectorType();
+    ArrayRef<bool> origScalableDims = origVT.getScalableDims();
+    const int64_t origVRank = origVT.getRank();
+    if (origVRank < 2 || llvm::count(origScalableDims, true) != 1)
+      return rewriter.notifyMatchFailure(readOp, "wrong dimensions");
+
+    // Number of trailing dimensions to collapse, including the scalable
+    // dimension. Nothing to do if the single scalable dimension is already
+    // the last one.
+    const int64_t numCollapseDims = std::distance(
+        llvm::find(origScalableDims, true), origScalableDims.end());
+    if (numCollapseDims < 2)
+      return rewriter.notifyMatchFailure(readOp,
+                                         "scalable dimension is trailing");
+
+    // We want a simple memref (not a tensor) with contiguous elements for at
+    // least all the trailing dimensions up to and including the scalable one.
+    auto memTy = dyn_cast<MemRefType>(readOp.getBase().getType());
+    if (!(memTy && memTy.areTrailingDimsContiguous(numCollapseDims)))
+      return rewriter.notifyMatchFailure(
+          readOp, "non-contiguous memref dimensions to collapse");
+
+    // The collapsed dimensions (excluding the scalable one) of the vector and
+    // the memref must match and the corresponding indices must be in-bounds
+    // (it follows these indices would be zero). This guarantees that the
+    // operation transfers a contiguous block.
momchil-velikov wrote:
This part wasn't tested at all. Test cases added.
https://github.com/llvm/llvm-project/pull/143146
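One plausible shape for such a test (the name and sizes here are assumptions; the tests actually added to the PR may differ): the vector's and memref's trailing dims disagree, so the transfer would not cover a contiguous block:

```mlir
func.func @negative_trailing_dim_mismatch(%i : index, %j : index,
    %M : memref<?x?x4x8xi8>) -> vector<[4]x2x8xi8> {
  %c0 = arith.constant 0 : index
  %c0_i8 = arith.constant 0 : i8
  // Vector trailing dims (2x8) differ from memref trailing dims (4x8),
  // so collapsing would not read contiguous memory; the pattern bails.
  %A = vector.transfer_read %M[%i, %j, %c0, %c0], %c0_i8
      {in_bounds = [true, true, true]}
      : memref<?x?x4x8xi8>, vector<[4]x2x8xi8>
  return %A : vector<[4]x2x8xi8>
}
```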
[llvm-branch-commits] [mlir] [MLIR] Legalize certain `vector.transfer_read` ops of scalable vectors (PR #143146)
@@ -298,16 +298,139 @@ struct LegalizeSVEMaskLoadConversion : public OpRewritePattern<memref::LoadOp> {
}
};
+/// Transforms a `transfer_read` operation so it reads a vector of a type that
+/// can be mapped to an LLVM type. This is done by collapsing trailing
+/// dimensions so we obtain a vector type with a single scalable dimension in
+/// the rightmost position.
+///
+/// Example:
+/// ```
+/// %v = vector.transfer_read %M[%i, %j, %c0, %c0], %c0_i8
+///   {in_bounds = [false, true, true, true]}
+///   : memref<?x?x2x8xi8>, vector<2x[4]x2x8xi8>
+/// ```
+/// is rewritten to
+/// ```
+/// %collapse_shape = memref.collapse_shape %M [[0], [1, 2, 3]]
+///   : memref<?x?x2x8xi8> into memref<?x?xi8>
+/// %0 = vector.transfer_read %collapse_shape[%i, %j], %c0_i8
+///   {in_bounds = [false, true]}
+///   : memref<?x?xi8>, vector<2x[64]xi8>
+/// %1 = vector.shape_cast %0 : vector<2x[64]xi8> to vector<2x[4]x2x8xi8>
+/// ```
+struct LegalizeTransferRead : public OpRewritePattern<vector::TransferReadOp> {
+  using OpRewritePattern::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(vector::TransferReadOp readOp,
+                                PatternRewriter &rewriter) const override {
+
+    // Do not try to transform masked reads. For example, if we have a
+    // transfer to a `vector<[4]x4xi8>` we could have a mask like
+    //    1 1 1 0
+    //    1 1 1 0
+    //    1 1 1 0
+    //    0 0 0 0
+    // Flattening this mask would look like
+    //    1 1 1 0 1 1 1 0 1 1 1 0 0 0 0 0
+    // and we have not yet figured out an efficient way to build such a mask,
+    // neither from the mask operand, nor from the original
+    // `vector.create_mask` operation (if visible at all).
+    if (readOp.isMasked() || readOp.getMask())
+      return rewriter.notifyMatchFailure(readOp,
+                                         "masked transfers not supported");
+
+    if (!readOp.getPermutationMap().isMinorIdentity())
+      return rewriter.notifyMatchFailure(readOp, "non-identity permutation");
+
+    // We handle transfers of vectors with rank >= 2 and a single scalable
+    // dimension.
+    VectorType origVT = readOp.getVectorType();
+    ArrayRef<bool> origScalableDims = origVT.getScalableDims();
+    const int64_t origVRank = origVT.getRank();
+    if (origVRank < 2 || llvm::count(origScalableDims, true) != 1)
momchil-velikov wrote:
Done.
https://github.com/llvm/llvm-project/pull/143146
[llvm-branch-commits] [mlir] [MLIR] Legalize certain `vector.transfer_read` ops of scalable vectors (PR #143146)
@@ -298,16 +298,139 @@ struct LegalizeSVEMaskLoadConversion : public OpRewritePattern<memref::LoadOp> {
}
};
+/// Transforms a `transfer_read` operation so it reads a vector of a type that
+/// can be mapped to an LLVM type. This is done by collapsing trailing
+/// dimensions so we obtain a vector type with a single scalable dimension in
+/// the rightmost position.
+///
+/// Example:
+/// ```
+/// %v = vector.transfer_read %M[%i, %j, %c0, %c0], %c0_i8
+///   {in_bounds = [false, true, true, true]}
+///   : memref<?x?x2x8xi8>, vector<2x[4]x2x8xi8>
+/// ```
+/// is rewritten to
+/// ```
+/// %collapse_shape = memref.collapse_shape %M [[0], [1, 2, 3]]
+///   : memref<?x?x2x8xi8> into memref<?x?xi8>
+/// %0 = vector.transfer_read %collapse_shape[%i, %j], %c0_i8
+///   {in_bounds = [false, true]}
+///   : memref<?x?xi8>, vector<2x[64]xi8>
+/// %1 = vector.shape_cast %0 : vector<2x[64]xi8> to vector<2x[4]x2x8xi8>
+/// ```
+struct LegalizeTransferRead : public OpRewritePattern<vector::TransferReadOp> {
+  using OpRewritePattern::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(vector::TransferReadOp readOp,
+                                PatternRewriter &rewriter) const override {
+
+    // Do not try to transform masked reads. For example, if we have a
+    // transfer to a `vector<[4]x4xi8>` we could have a mask like
+    //    1 1 1 0
+    //    1 1 1 0
+    //    1 1 1 0
+    //    0 0 0 0
+    // Flattening this mask would look like
+    //    1 1 1 0 1 1 1 0 1 1 1 0 0 0 0 0
+    // and we have not yet figured out an efficient way to build such a mask,
+    // neither from the mask operand, nor from the original
+    // `vector.create_mask` operation (if visible at all).
+    if (readOp.isMasked() || readOp.getMask())
+      return rewriter.notifyMatchFailure(readOp,
+                                         "masked transfers not supported");
+
+    if (!readOp.getPermutationMap().isMinorIdentity())
+      return rewriter.notifyMatchFailure(readOp, "non-identity permutation");
+
+    // We handle transfers of vectors with rank >= 2 and a single scalable
+    // dimension.
banach-space wrote:
[nit] It would be helpful to add _why_:
* No need to worry about 1-D, that's supported by default.
* More than one scalable dim is tricky (how to collapse e.g. `vscale * vscale`?)
https://github.com/llvm/llvm-project/pull/143146
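For the second bullet, a hypothetical shape that shows the problem (not something the patch handles; the function name is an assumption): with two scalable dims the flattened extent would be a multiple of `vscale * vscale`, which no 1-D vector type can express:

```mlir
func.func @two_scalable_dims(%M : memref<?x?xi8>, %i : index, %j : index)
    -> vector<[4]x[8]xi8> {
  %c0_i8 = arith.constant 0 : i8
  // 4*vscale x 8*vscale elements in total; a 1-D form would need
  // vector<[32*vscale]xi8>, which does not exist, so the pattern must
  // bail out on vectors with more than one scalable dimension.
  %v = vector.transfer_read %M[%i, %j], %c0_i8
      {in_bounds = [true, true]} : memref<?x?xi8>, vector<[4]x[8]xi8>
  return %v : vector<[4]x[8]xi8>
}
```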
[llvm-branch-commits] [mlir] [MLIR] Legalize certain `vector.transfer_read` ops of scalable vectors (PR #143146)
@@ -298,16 +298,139 @@ struct LegalizeSVEMaskLoadConversion : public OpRewritePattern<memref::LoadOp> {
}
};
+/// Transforms a `transfer_read` operation so it reads a vector of a type that
+/// can be mapped to an LLVM type. This is done by collapsing trailing
+/// dimensions so we obtain a vector type with a single scalable dimension in
+/// the rightmost position.
+///
+/// Example:
+/// ```
+/// %v = vector.transfer_read %M[%i, %j, %c0, %c0], %c0_i8
+///   {in_bounds = [false, true, true, true]}
+///   : memref<?x?x2x8xi8>, vector<2x[4]x2x8xi8>
+/// ```
+/// is rewritten to
+/// ```
+/// %collapse_shape = memref.collapse_shape %M [[0], [1, 2, 3]]
+///   : memref<?x?x2x8xi8> into memref<?x?xi8>
+/// %0 = vector.transfer_read %collapse_shape[%i, %j], %c0_i8
+///   {in_bounds = [false, true]}
+///   : memref<?x?xi8>, vector<2x[64]xi8>
+/// %1 = vector.shape_cast %0 : vector<2x[64]xi8> to vector<2x[4]x2x8xi8>
+/// ```
+struct LegalizeTransferRead : public OpRewritePattern<vector::TransferReadOp> {
+  using OpRewritePattern::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(vector::TransferReadOp readOp,
+                                PatternRewriter &rewriter) const override {
+
+    // Do not try to transform masked reads. For example, if we have a
+    // transfer to a `vector<[4]x4xi8>` we could have a mask like
+    //    1 1 1 0
+    //    1 1 1 0
+    //    1 1 1 0
+    //    0 0 0 0
+    // Flattening this mask would look like
+    //    1 1 1 0 1 1 1 0 1 1 1 0 0 0 0 0
+    // and we have not yet figured out an efficient way to build such a mask,
+    // neither from the mask operand, nor from the original
+    // `vector.create_mask` operation (if visible at all).
+    if (readOp.isMasked() || readOp.getMask())
+      return rewriter.notifyMatchFailure(readOp,
+                                         "masked transfers not supported");
+
+    if (!readOp.getPermutationMap().isMinorIdentity())
+      return rewriter.notifyMatchFailure(readOp, "non-identity permutation");
+
+    // We handle transfers of vectors with rank >= 2 and a single scalable
+    // dimension.
+    VectorType origVT = readOp.getVectorType();
+    ArrayRef<bool> origScalableDims = origVT.getScalableDims();
+    const int64_t origVRank = origVT.getRank();
+    if (origVRank < 2 || llvm::count(origScalableDims, true) != 1)
banach-space wrote:
[nit] [getNumScalableDims](https://github.com/banach-space/llvm-project/blob/c15e7dddaea765eab4f9ed73e79b762138dc4ac0/mlir/include/mlir/IR/BuiltinTypes.td#L1368-L1371) would be more canonical than `llvm::count`.
https://github.com/llvm/llvm-project/pull/143146
[llvm-branch-commits] [mlir] [MLIR] Legalize certain `vector.transfer_read` ops of scalable vectors (PR #143146)
@@ -0,0 +1,262 @@
+// RUN: mlir-opt --arm-sve-legalize-vector-storage --split-input-file %s | FileCheck %s
+
+// -----
+
+// CHECK-LABEL: @test_base_case
+// CHECK-SAME:    %[[I:arg0]]: index, %[[J:arg1]]: index, %[[M:arg2]]:

banach-space wrote:

Is it guaranteed that `%i` will be renamed as `arg0` after the transformation? AFAIK, no, but perhaps I am missing something?

https://github.com/llvm/llvm-project/pull/143146
[llvm-branch-commits] [mlir] [MLIR] Legalize certain `vector.transfer_read` ops of scalable vectors (PR #143146)
@@ -0,0 +1,262 @@
+// RUN: mlir-opt --arm-sve-legalize-vector-storage --split-input-file %s | FileCheck %s
+
+// -----
+
+// CHECK-LABEL: @test_base_case
+// CHECK-SAME:    %[[I:arg0]]: index, %[[J:arg1]]: index, %[[M:arg2]]:
+// CHECK:         %[[COLLAPSE:.+]] = memref.collapse_shape %[[M]]
+// CHECK-SAME{LITERAL}:  [[0], [1], [2, 3]]
+// CHECK-SAME:      : memref<?x?x?x8xi8> into memref<?x?x?xi8>
+// CHECK-NEXT:    %[[T0:.+]] = vector.transfer_read %[[COLLAPSE]][%[[I]], %[[J]], %c0], %c0_i8 {in_bounds = [true]}
+// CHECK-SAME:      : memref<?x?x?xi8>, vector<[32]xi8>
+// CHECK-NEXT:    %[[T1:.+]] = vector.shape_cast %[[T0]] : vector<[32]xi8> to vector<[4]x8xi8>
+// CHECK-NEXT:    return %[[T1]] : vector<[4]x8xi8>
+
+func.func @test_base_case(%i : index, %j : index, %M : memref<?x?x?x8xi8>) -> vector<[4]x8xi8> {
+  %c0 = arith.constant 0 : index
+  %c0_i8 = arith.constant 0 : i8
+
+  %A = vector.transfer_read %M[%i, %j, %c0, %c0], %c0_i8 {in_bounds = [true, true]} : memref<?x?x?x8xi8>, vector<[4]x8xi8>
+
+  return %A : vector<[4]x8xi8>
+}
+
+// -----
+
+// CHECK-LABEL: @test_using_strided_layout
+// CHECK-SAME:    %[[I:arg0]]: index, %[[J:arg1]]: index, %[[M:arg2]]
+// CHECK:         %[[COLLAPSE:.+]] = memref.collapse_shape %[[M]]
+// CHECK-SAME{LITERAL}:  [[0], [1], [2, 3]]
+// CHECK-SAME:      : memref<?x?x?x8xi8, strided<[?, ?, 8, 1]>> into
+// CHECK-SAME:      memref<?x?x?xi8, strided<[?, ?, 1]>>
+// CHECK-NEXT:    %[[T0:.+]] = vector.transfer_read %[[COLLAPSE]][%[[I]], %[[J]], %c0], %c0_i8 {in_bounds = [true]}
+// CHECK-SAME:      : memref<?x?x?xi8, strided<[?, ?, 1]>>, vector<[32]xi8>
+// CHECK-NEXT:    %[[T1:.+]] = vector.shape_cast %[[T0]] : vector<[32]xi8> to vector<[4]x8xi8>
+// CHECK-NEXT:    return %[[T1]] : vector<[4]x8xi8>
+
+#s0 = strided<[?, ?, 8, 1]>
+
+func.func @test_using_strided_layout(%i : index, %j : index, %M : memref<?x?x?x8xi8, #s0>) -> vector<[4]x8xi8> {
+  %c0 = arith.constant 0 : index
+  %c0_i8 = arith.constant 0 : i8
+
+  %A = vector.transfer_read %M[%i, %j, %c0, %c0], %c0_i8 {in_bounds = [true, true]} : memref<?x?x?x8xi8, #s0>, vector<[4]x8xi8>
+
+  return %A : vector<[4]x8xi8>
+}
+
+// -----
+
+// CHECK-LABEL: @test_3d_vector
+// CHECK-SAME:    %[[I:arg0]]: index, %[[J:arg1]]: index, %[[M:arg2]]
+// CHECK:         %[[COLLAPSED:.+]] = memref.collapse_shape %[[M]]
+// CHECK-SAME{LITERAL}:  [[0], [1, 2, 3]]
+// CHECK-SAME:      : memref<?x?x2x8xi8, strided<[?, 16, 8, 1]>> into
+// CHECK-SAME:      memref<?x?xi8, strided<[?, 1]>>
+// CHECK-NEXT:    %[[T0:.+]] = vector.transfer_read %[[COLLAPSED]][%[[I]], %[[J]]], %c0_i8 {in_bounds = [true]}
+// CHECK-SAME:      : memref<?x?xi8, strided<[?, 1]>>, vector<[64]xi8>
+// CHECK-NEXT:    %[[T1:.+]] = vector.shape_cast %[[T0]] : vector<[64]xi8> to vector<[4]x2x8xi8>
+// CHECK-NEXT:    return %[[T1]] : vector<[4]x2x8xi8>
+
+#s1 = strided<[?, 16, 8, 1]>
+
+func.func @test_3d_vector(%i : index, %j : index, %M : memref<?x?x2x8xi8, #s1>) -> vector<[4]x2x8xi8> {
+  %c0 = arith.constant 0 : index
+  %c0_i8 = arith.constant 0 : i8
+
+  %A = vector.transfer_read %M[%i, %j, %c0, %c0], %c0_i8 {in_bounds = [true, true, true]} : memref<?x?x2x8xi8, #s1>, vector<[4]x2x8xi8>
+
+  return %A : vector<[4]x2x8xi8>
+}
+
+// -----
+
+// CHECK-LABEL: @test_4d_vector
+// CHECK-SAME:    %[[I:arg0]]: index, %[[J:arg1]]: index, %[[M:arg2]]
+// CHECK:         %[[COLLAPSED:.+]] = memref.collapse_shape %[[M]]
+// CHECK-SAME{LITERAL}:  [[0], [1, 2, 3]]
+// CHECK-SAME:      : memref<?x?x2x8xi8, strided<[?, 16, 8, 1]>> into
+// CHECK-SAME:      memref<?x?xi8, strided<[?, 1]>>
+// CHECK-NEXT:    %[[T0:.+]] = vector.transfer_read %[[COLLAPSED]][%[[I]], %[[J]]], %c0_i8 {in_bounds = [false, true]}
+// CHECK-SAME:      : memref<?x?xi8, strided<[?, 1]>>, vector<2x[64]xi8>
+// CHECK-NEXT:    %[[T1:.+]] = vector.shape_cast %[[T0]] : vector<2x[64]xi8> to vector<2x[4]x2x8xi8>
+// CHECK-NEXT:    return %[[T1]] : vector<2x[4]x2x8xi8>
+
+#s2 = strided<[?, 16, 8, 1]>
+
+func.func @test_4d_vector(%i : index, %j : index, %M : memref<?x?x2x8xi8, #s2>) -> vector<2x[4]x2x8xi8> {
+  %c0 = arith.constant 0 : index
+  %c0_i8 = arith.constant 0 : i8
+
+  %A = vector.transfer_read %M[%i, %j, %c0, %c0], %c0_i8 {in_bounds = [false, true, true, true]} : memref<?x?x2x8xi8, #s2>, vector<2x[4]x2x8xi8>
+
+  return %A : vector<2x[4]x2x8xi8>
+}
+
+// -----
+
+// CHECK-LABEL: @negative_test_vector_legal_non_scalable
+// CHECK-NOT: memref.collapse
+
+func.func @negative_test_vector_legal_non_scalable(%i : index, %j : index, %M : memref<?x?x?x8xi8>) -> vector<8x8xi8> {
+  %c0 = arith.constant 0 : index
+  %c0_i8 = arith.constant 0 : i8
+
+  %A = vector.transfer_read %M[%i, %j, %c0, %c0], %c0_i8 {in_bounds = [true, true]} : memref<?x?x?x8xi8>, vector<8x8xi8>
+
+  return %A : vector<8x8xi8>
+}
+
+// -----
+
+// CHECK-LABEL: @negative_test_vector_legal_scalable_0
+// CHECK-NOT: memref.collapse
+
+func.func @negative_test_vector_legal_scalable_0(%i : index, %j : index, %M : memref<?x?x?x8xi8>) -> vector<[8]xi8> {
+  %c0 = arith.constant 0 : index
+  %c0_i8 = arith.constant 0 : i8
+
+  %A = vector.transfer_read %M[%i, %j, %c0, %c0], %c0_i8 {in_bounds = [true]} : memref<?x?x?x8xi8>, vector<[8]xi8>
+
+  return %A : ve
[llvm-branch-commits] [mlir] [MLIR] Legalize certain `vector.transfer_read` ops of scalable vectors (PR #143146)
@@ -298,16 +298,139 @@ struct LegalizeSVEMaskLoadConversion : public OpRewritePattern<memref::LoadOp> {
   }
 };
+/// Transforms a `transfer_read` operation so it reads a vector of a type that
+/// can be mapped to an LLVM type. This is done by collapsing trailing
+/// dimensions so we obtain a vector type with a single scalable dimension in
+/// the rightmost position.
+///
+/// Example:
+/// ```
+/// %v = vector.transfer_read %M[%i, %j, %c0, %c0], %c0_i8
+///   {in_bounds = [false, true, true, true]}
+///   : memref<?x?x2x8xi8>, vector<2x[4]x2x8xi8>
+/// ```
+/// is rewritten to
+/// ```
+/// %collapse_shape = memref.collapse_shape %M [[0], [1, 2, 3]]
+///   : memref<?x?x2x8xi8> into memref<?x?xi8>
+/// %0 = vector.transfer_read %collapse_shape[%i, %j], %c0_i8
+///   {in_bounds = [false, true]}
+///   : memref<?x?xi8>, vector<2x[64]xi8>
+/// %1 = vector.shape_cast %0 : vector<2x[64]xi8> to vector<2x[4]x2x8xi8>
+/// ```
+struct LegalizeTransferRead : public OpRewritePattern<vector::TransferReadOp> {
+  using OpRewritePattern::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(vector::TransferReadOp readOp,
+                                PatternRewriter &rewriter) const override {
+
+    // Do not try to transform masked reads. For example, if we have a transfer
+    // to a `vector<[4]x4xi8>` we could have a mask like
+    //    1 1 1 0
+    //    1 1 1 0
+    //    1 1 1 0
+    //    0 0 0 0
+    // Flattening this mask would look like
+    //    1 1 1 0 1 1 1 0 1 1 1 0 0 0 0 0
+    // and we have not yet figured out an efficient way to build such a mask,
+    // neither from the mask operand, nor from the original
+    // `vector.create_mask` operation (if visible at all).
+    if (readOp.isMasked() || readOp.getMask())
+      return rewriter.notifyMatchFailure(readOp,
+                                         "masked transfers not-supported");
+
+    if (!readOp.getPermutationMap().isMinorIdentity())
+      return rewriter.notifyMatchFailure(readOp, "non-identity permutation");
banach-space wrote:
Would supporting non-identity permutations be a problem? It would be good to add a comment, either:
* `TODO: We haven't required this, so leaving for later.` or
* "Too complex because of <reason>, disabling".

Any hint for future developers would be helpful.
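To make the request concrete, here is a hypothetical read that the pattern currently rejects, because its permutation map transposes the two dimensions (names and shapes are illustrative only):

```mlir
#transpose = affine_map<(d0, d1) -> (d1, d0)>

func.func @transposed_read(%i : index, %j : index, %M : memref<?x?xi8>) -> vector<[8]x2xi8> {
  %c0_i8 = arith.constant 0 : i8
  // The map is not a minor identity, so LegalizeTransferRead bails out
  // with "non-identity permutation".
  %v = vector.transfer_read %M[%i, %j], %c0_i8
      {permutation_map = #transpose, in_bounds = [true, true]}
      : memref<?x?xi8>, vector<[8]x2xi8>
  return %v : vector<[8]x2xi8>
}
```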
https://github.com/llvm/llvm-project/pull/143146
___
llvm-branch-commits mailing list
[email protected]
https://lists.llvm.org/cgi-bin/mailman/listinfo/llvm-branch-commits
[llvm-branch-commits] [mlir] [MLIR] Legalize certain `vector.transfer_read` ops of scalable vectors (PR #143146)
@@ -298,16 +298,139 @@ struct LegalizeSVEMaskLoadConversion : public OpRewritePattern<memref::LoadOp> {
   }
 };
+/// Transforms a `transfer_read` operation so it reads a vector of a type that
+/// can be mapped to an LLVM type. This is done by collapsing trailing
+/// dimensions so we obtain a vector type with a single scalable dimension in
+/// the rightmost position.
+///
+/// Example:
+/// ```
+/// %v = vector.transfer_read %M[%i, %j, %c0, %c0], %c0_i8
+///   {in_bounds = [false, true, true, true]}
+///   : memref<?x?x2x8xi8>, vector<2x[4]x2x8xi8>
+/// ```
+/// is rewritten to
+/// ```
+/// %collapse_shape = memref.collapse_shape %M [[0], [1, 2, 3]]
+///   : memref<?x?x2x8xi8> into memref<?x?xi8>
+/// %0 = vector.transfer_read %collapse_shape[%i, %j], %c0_i8
+///   {in_bounds = [false, true]}
+///   : memref<?x?xi8>, vector<2x[64]xi8>
+/// %1 = vector.shape_cast %0 : vector<2x[64]xi8> to vector<2x[4]x2x8xi8>
+/// ```
+struct LegalizeTransferRead : public OpRewritePattern<vector::TransferReadOp> {
+  using OpRewritePattern::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(vector::TransferReadOp readOp,
+                                PatternRewriter &rewriter) const override {
+
+    // Do not try to transform masked reads. For example, if we have a transfer
+    // to a `vector<[4]x4xi8>` we could have a mask like
+    //    1 1 1 0
+    //    1 1 1 0
+    //    1 1 1 0
+    //    0 0 0 0
+    // Flattening this mask would look like
+    //    1 1 1 0 1 1 1 0 1 1 1 0 0 0 0 0
+    // and we have not yet figured out an efficient way to build such a mask,
+    // neither from the mask operand, nor from the original
+    // `vector.create_mask` operation (if visible at all).
+    if (readOp.isMasked() || readOp.getMask())
+      return rewriter.notifyMatchFailure(readOp,
+                                         "masked transfers not-supported");
+
+    if (!readOp.getPermutationMap().isMinorIdentity())
+      return rewriter.notifyMatchFailure(readOp, "non-identity permutation");
+
+    // We handle transfers of vectors with rank >= 2 and a single scalable
+    // dimension.
+    VectorType origVT = readOp.getVectorType();
+    ArrayRef<bool> origScalableDims = origVT.getScalableDims();
+    const int64_t origVRank = origVT.getRank();
+    if (origVRank < 2 || llvm::count(origScalableDims, true) != 1)
+      return rewriter.notifyMatchFailure(readOp, "wrong dimensions");
+
+    // Number of trailing dimensions to collapse, including the scalable
+    // dimension. Nothing to do if the single scalable dimension is already the
+    // last one.
+    const int64_t numCollapseDims = std::distance(
+        llvm::find(origScalableDims, true), origScalableDims.end());
+    if (numCollapseDims < 2)
+      return rewriter.notifyMatchFailure(readOp,
+                                         "scalable dimension is trailing");
+
+    // We want a simple memref (not a tensor) with contiguous elements for at
+    // least all the trailing dimensions up to and including the scalable one.
+    auto memTy = dyn_cast<MemRefType>(readOp.getBase().getType());
+    if (!(memTy && memTy.areTrailingDimsContiguous(numCollapseDims)))
+      return rewriter.notifyMatchFailure(
+          readOp, "non-contiguous memref dimensions to collapse");
+
+    // The collapsed dimensions (excluding the scalable one) of the vector and
+    // the memref must match and the corresponding indices must be in-bounds (it
+    // follows these indices would be zero). This guarantees that the operation
+    // transfers a contiguous block.
+    if (!llvm::equal(memTy.getShape().take_back(numCollapseDims - 1),
+                     origVT.getShape().take_back(numCollapseDims - 1)))
+      return rewriter.notifyMatchFailure(
+          readOp, "memref and vector dimensions do not match");
+
+    SmallVector<bool> origInBounds = readOp.getInBoundsValues();
+    if (!llvm::all_of(
+            ArrayRef<bool>(origInBounds).take_back(numCollapseDims - 1),
+            [](bool v) { return v; }))
+      return rewriter.notifyMatchFailure(readOp,
+                                         "out-if-bounds index to collapse");
banach-space wrote:
Note, it's not really the index that's out-of-bounds, but the corresponding memory
access. So, the index could be in-bounds, yet we might be reading "more" than
there is available to read (starting at that index). For example:
```mlir
vector.transfer_read %mem[5] : memref<7xi8>, vector<7xi8>
```
```suggestion
"out-of-bounds index to collapse");
```
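On the same hunk, the `llvm::all_of` with an identity lambda could arguably be spelled `!llvm::is_contained(ArrayRef(origInBounds).take_back(numCollapseDims - 1), false)` — just a sketch, assuming `llvm::is_contained` from `llvm/ADT/STLExtras.h`; the behavior should be identical.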
https://github.com/llvm/llvm-project/pull/143146
___
llvm-branch-commits mailing list
[email protected]
https://lists.llvm.org/cgi-bin/mailman/listinfo/llvm-branch-commits
[llvm-branch-commits] [mlir] [MLIR] Legalize certain `vector.transfer_read` ops of scalable vectors (PR #143146)
banach-space wrote:
[nit] Avoid using the word `test` in test function names. It's just noise that doesn't add any new info. Instead, try to convey what makes a particular test case unique. See here for MLIR guidelines:
https://mlir.llvm.org/getting_started/TestingGuide/#test-formatting-best-practices
https://github.com/llvm/llvm-project/pull/143146
___
llvm-branch-commits mailing list
[email protected]
https://lists.llvm.org/cgi-bin/mailman/listinfo/llvm-branch-commits
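In that spirit, the names might instead describe what each case exercises — purely illustrative suggestions, not part of the posted patch:

```mlir
// @test_base_case            -> @collapse_scalable_dim_plus_trailing_dim
// @test_using_strided_layout -> @collapse_scalable_dim_strided_layout
// @test_3d_vector            -> @collapse_scalable_dim_plus_two_trailing_dims
// @test_4d_vector            -> @collapse_scalable_dim_with_leading_dim
```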
[llvm-branch-commits] [mlir] [MLIR] Legalize certain `vector.transfer_read` ops of scalable vectors (PR #143146)
https://github.com/banach-space commented:

Great work, Momchil - thank you!

I've left a number of comments, but nothing major. My main high-level suggestion is to follow the guidance in [MLIR's Testing Guide](https://mlir.llvm.org/getting_started/TestingGuide/#contributor-guidelines) a bit more closely. It's a relatively new (and long!) document, so I've included specific in-line suggestions to make it easier to see where things could align better. For additional context, this [RFC](https://discourse.llvm.org/t/rfc-should-we-aim-for-more-consistency-in-tests/) provides some of the rationale behind that approach.

Also - what about memrefs with dynamic dimensions?

https://github.com/llvm/llvm-project/pull/143146
___
llvm-branch-commits mailing list
[email protected]
https://lists.llvm.org/cgi-bin/mailman/listinfo/llvm-branch-commits
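On the dynamic-dimension question: as written, the pattern compares the memref's trailing static sizes against the vector's, so a memref whose trailing dimensions are dynamic would simply not match — a hypothetical case the pattern currently skips:

```mlir
func.func @dynamic_trailing_dim(%i : index, %j : index, %M : memref<?x?x?xi8>) -> vector<[4]x8xi8> {
  %c0 = arith.constant 0 : index
  %c0_i8 = arith.constant 0 : i8
  // The trailing memref dimension is ? rather than 8, so the shape
  // equality check fails and no collapse is performed.
  %A = vector.transfer_read %M[%i, %j, %c0], %c0_i8 {in_bounds = [true, true]}
      : memref<?x?x?xi8>, vector<[4]x8xi8>
  return %A : vector<[4]x8xi8>
}
```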
[llvm-branch-commits] [mlir] [MLIR] Legalize certain `vector.transfer_read` ops of scalable vectors (PR #143146)
https://github.com/momchil-velikov updated
https://github.com/llvm/llvm-project/pull/143146
>From 493955781f28b8b6caaeff1b45f7b7a06b43086c Mon Sep 17 00:00:00 2001
From: Momchil Velikov
Date: Wed, 14 May 2025 09:03:49 +0000
Subject: [PATCH 1/3] [MLIR] Legalize certain `vector.transfer_read` ops of
scalable vectors
This patch adds a transform of the `transfer_read` operation to change the
vector type to one that can be mapped to an LLVM type. This is done by
collapsing trailing dimensions so we obtain a vector type with a single
scalable dimension in the rightmost position.
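For instance, `vector<2x[4]x2x8xi8>` becomes `vector<2x[64]xi8>`: the trailing
`[4]x2x8` block, which begins at the scalable dimension, flattens into a single
scalable dimension of `4 * 2 * 8 = 64` elements.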
---
.../Transforms/LegalizeVectorStorage.cpp | 110 -
.../ArmSVE/legalize-transfer-read.mlir| 226 ++
.../transfer-read-scalable-not-rightmost.mlir | 72 ++
3 files changed, 407 insertions(+), 1 deletion(-)
create mode 100644 mlir/test/Dialect/ArmSVE/legalize-transfer-read.mlir
create mode 100644
mlir/test/Integration/Dialect/Vector/CPU/ArmSVE/transfer-read-scalable-not-rightmost.mlir
diff --git a/mlir/lib/Dialect/ArmSVE/Transforms/LegalizeVectorStorage.cpp
b/mlir/lib/Dialect/ArmSVE/Transforms/LegalizeVectorStorage.cpp
index d2ac850a5f70b..f16d33c004fec 100644
--- a/mlir/lib/Dialect/ArmSVE/Transforms/LegalizeVectorStorage.cpp
+++ b/mlir/lib/Dialect/ArmSVE/Transforms/LegalizeVectorStorage.cpp
@@ -298,6 +298,113 @@ struct LegalizeSVEMaskLoadConversion : public
OpRewritePattern {
}
};
+/// Transforms a `transfer_read` operation so it reads a vector of a type that
+/// can be mapped to an LLVM type. This is done by collapsing trailing
+/// dimensions so we obtain a vector type with a single scalable dimension in
+/// the rightmost position.
+///
+/// Example:
+/// ```
+/// %v = vector.transfer_read %M[%i, %j, %c0, %c0], %c0_i8
+///   {in_bounds = [false, true, true, true]}
+///   : memref<?x?x2x8xi8>, vector<2x[4]x2x8xi8>
+/// ```
+/// is rewritten to
+/// ```
+/// %collapse_shape = memref.collapse_shape %M [[0], [1, 2, 3]]
+///   : memref<?x?x2x8xi8> into memref<?x?xi8>
+/// %0 = vector.transfer_read %collapse_shape[%i, %j], %c0_i8
+///   {in_bounds = [false, true]}
+///   : memref<?x?xi8>, vector<2x[64]xi8>
+/// %1 = vector.shape_cast %0 : vector<2x[64]xi8> to vector<2x[4]x2x8xi8>
+/// ```
+struct LegalizeTransferRead : public OpRewritePattern {
+ using OpRewritePattern::OpRewritePattern;
+
+ LogicalResult matchAndRewrite(vector::TransferReadOp readOp,
+PatternRewriter &rewriter) const override {
+
+if (!readOp.getPermutationMap().isMinorIdentity())
+ return rewriter.notifyMatchFailure(readOp, "non-identity permutation");
+
+// We handle transfers of vectors with rank >= 2 and a single scalable
+// dimension.
+VectorType origVT = readOp.getVectorType();
+ArrayRef origScalableDims = origVT.getScalableDims();
+const int64_t origVRank = origVT.getRank();
+if (origVRank < 2 || llvm::count(origScalableDims, true) != 1)
+ return rewriter.notifyMatchFailure(readOp, "wrong dimensions");
+
+// Number of trailing dimensions to collapse, including the scalable
+// dimension. Nothing to do if the single scalable dimension is already
the
+// last one.
+const int64_t numCollapseDims = std::distance(
+llvm::find(origScalableDims, true), origScalableDims.end());
+if (numCollapseDims < 2)
+ return rewriter.notifyMatchFailure(readOp,
+ "scalable dimension is trailing");
+
+// We want a simple memref (not a tensor) with contiguous elements for at
+// least all the trailing dimensions up to and including the scalable one.
+auto memTy = dyn_cast(readOp.getBase().getType());
+if (!(memTy && memTy.areTrailingDimsContiguous(numCollapseDims)))
+ return rewriter.notifyMatchFailure(
+ readOp, "non-contiguous memref dimensions to collapse");
+
+// The collapsed dimensions (excluding the scalable one) of the vector and
+// the memref must match and the corresponding indices must be in-bounds
(it
+// follows these indices would be zero). This guarantees that the operation
+// transfers a contiguous block.
+if (!llvm::equal(memTy.getShape().take_back(numCollapseDims - 1),
+ origVT.getShape().take_back(numCollapseDims - 1)))
+ return rewriter.notifyMatchFailure(
+ readOp, "memref and vector dimensions do not match");
+
+SmallVector origInBounds = readOp.getInBoundsValues();
+if (!llvm::all_of(
+ArrayRef(origInBounds).take_back(numCollapseDims - 1),
+[](bool v) { return v; }))
+ return rewriter.notifyMatchFailure(readOp,
+ "out-if-bounds index to collapse");
+
+// Collapse the trailing dimensions of the memref.
+SmallVector reassoc;
+for (int64_t i = 0; i < memTy.getRank() - numCollapseDims + 1; ++i)
+ reassoc.push_back({i});
+for (int64_t i = memTy.getRank() - numCollapseDims + 1; i <
memTy.getRank();
+ ++i)
+ reassoc.
[llvm-branch-commits] [mlir] [MLIR] Legalize certain `vector.transfer_read` ops of scalable vectors (PR #143146)
https://github.com/momchil-velikov updated
https://github.com/llvm/llvm-project/pull/143146
>From 6a6d6037b6da51b2da474c99751433542cf35603 Mon Sep 17 00:00:00 2001
From: Momchil Velikov
Date: Wed, 14 May 2025 09:03:49 +0000
Subject: [PATCH 1/2] [MLIR] Legalize certain `vector.transfer_read` ops of
scalable vectors
This patch adds a transform of the `transfer_read` operation to change the
vector type to one that can be mapped to an LLVM type. This is done by
collapsing trailing dimensions so we obtain a vector type with a single
scalable dimension in the rightmost position.
---
.../Transforms/LegalizeVectorStorage.cpp | 110 -
.../ArmSVE/legalize-transfer-read.mlir| 226 ++
.../transfer-read-scalable-not-rightmost.mlir | 72 ++
3 files changed, 407 insertions(+), 1 deletion(-)
create mode 100644 mlir/test/Dialect/ArmSVE/legalize-transfer-read.mlir
create mode 100644
mlir/test/Integration/Dialect/Vector/CPU/ArmSVE/transfer-read-scalable-not-rightmost.mlir
diff --git a/mlir/lib/Dialect/ArmSVE/Transforms/LegalizeVectorStorage.cpp
b/mlir/lib/Dialect/ArmSVE/Transforms/LegalizeVectorStorage.cpp
index d2ac850a5f70b..f16d33c004fec 100644
--- a/mlir/lib/Dialect/ArmSVE/Transforms/LegalizeVectorStorage.cpp
+++ b/mlir/lib/Dialect/ArmSVE/Transforms/LegalizeVectorStorage.cpp
@@ -298,6 +298,113 @@ struct LegalizeSVEMaskLoadConversion : public
OpRewritePattern {
}
};
+/// Transforms a `transfer_read` operation so it reads a vector of a type that
+/// can be mapped to an LLVM type. This is done by collapsing trailing
+/// dimensions so we obtain a vector type with a single scalable dimension in
+/// the rightmost position.
+///
+/// Example:
+/// ```
+/// %v = vector.transfer_read %M[%i, %j, %c0, %c0], %c0_i8
+///   {in_bounds = [false, true, true, true]}
+///   : memref<?x?x2x8xi8>, vector<2x[4]x2x8xi8>
+/// ```
+/// is rewritten to
+/// ```
+/// %collapse_shape = memref.collapse_shape %M [[0], [1, 2, 3]]
+///   : memref<?x?x2x8xi8> into memref<?x?xi8>
+/// %0 = vector.transfer_read %collapse_shape[%i, %j], %c0_i8
+///   {in_bounds = [false, true]}
+///   : memref<?x?xi8>, vector<2x[64]xi8>
+/// %1 = vector.shape_cast %0 : vector<2x[64]xi8> to vector<2x[4]x2x8xi8>
+/// ```
+struct LegalizeTransferRead : public OpRewritePattern {
+ using OpRewritePattern::OpRewritePattern;
+
+ LogicalResult matchAndRewrite(vector::TransferReadOp readOp,
+PatternRewriter &rewriter) const override {
+
+if (!readOp.getPermutationMap().isMinorIdentity())
+ return rewriter.notifyMatchFailure(readOp, "non-identity permutation");
+
+// We handle transfers of vectors with rank >= 2 and a single scalable
+// dimension.
+VectorType origVT = readOp.getVectorType();
+ArrayRef origScalableDims = origVT.getScalableDims();
+const int64_t origVRank = origVT.getRank();
+if (origVRank < 2 || llvm::count(origScalableDims, true) != 1)
+ return rewriter.notifyMatchFailure(readOp, "wrong dimensions");
+
+// Number of trailing dimensions to collapse, including the scalable
+// dimension. Nothing to do if the single scalable dimension is already
the
+// last one.
+const int64_t numCollapseDims = std::distance(
+llvm::find(origScalableDims, true), origScalableDims.end());
+if (numCollapseDims < 2)
+ return rewriter.notifyMatchFailure(readOp,
+ "scalable dimension is trailing");
+
+// We want a simple memref (not a tensor) with contiguous elements for at
+// least all the trailing dimensions up to and including the scalable one.
+auto memTy = dyn_cast(readOp.getBase().getType());
+if (!(memTy && memTy.areTrailingDimsContiguous(numCollapseDims)))
+ return rewriter.notifyMatchFailure(
+ readOp, "non-contiguous memref dimensions to collapse");
+
+// The collapsed dimensions (excluding the scalable one) of the vector and
+// the memref must match and the corresponding indices must be in-bounds
(it
+// follows these indices would be zero). This guarantees that the operation
+// transfers a contiguous block.
+if (!llvm::equal(memTy.getShape().take_back(numCollapseDims - 1),
+ origVT.getShape().take_back(numCollapseDims - 1)))
+ return rewriter.notifyMatchFailure(
+ readOp, "memref and vector dimensions do not match");
+
+SmallVector origInBounds = readOp.getInBoundsValues();
+if (!llvm::all_of(
+ArrayRef(origInBounds).take_back(numCollapseDims - 1),
+[](bool v) { return v; }))
+ return rewriter.notifyMatchFailure(readOp,
+ "out-if-bounds index to collapse");
+
+// Collapse the trailing dimensions of the memref.
+SmallVector reassoc;
+for (int64_t i = 0; i < memTy.getRank() - numCollapseDims + 1; ++i)
+ reassoc.push_back({i});
+for (int64_t i = memTy.getRank() - numCollapseDims + 1; i <
memTy.getRank();
+ ++i)
+ reassoc.
[llvm-branch-commits] [mlir] [MLIR] Legalize certain `vector.transfer_read` ops of scalable vectors (PR #143146)
https://github.com/momchil-velikov updated
https://github.com/llvm/llvm-project/pull/143146
>From 4d13aa2c48a24e5a76618e78adde41713115f895 Mon Sep 17 00:00:00 2001
From: Momchil Velikov
Date: Wed, 14 May 2025 09:03:49 +0000
Subject: [PATCH] [MLIR] Legalize certain `vector.transfer_read` ops of
scalable vectors
This patch adds a transform of the `transfer_read` operation to change the
vector type to one that can be mapped to an LLVM type. This is done by
collapsing trailing dimensions so we obtain a vector type with a single
scalable dimension in the rightmost position.
---
.../Transforms/LegalizeVectorStorage.cpp | 110 -
.../ArmSVE/legalize-transfer-read.mlir| 226 ++
.../transfer-read-scalable-not-rightmost.mlir | 72 ++
3 files changed, 407 insertions(+), 1 deletion(-)
create mode 100644 mlir/test/Dialect/ArmSVE/legalize-transfer-read.mlir
create mode 100644
mlir/test/Integration/Dialect/Vector/CPU/ArmSVE/transfer-read-scalable-not-rightmost.mlir
diff --git a/mlir/lib/Dialect/ArmSVE/Transforms/LegalizeVectorStorage.cpp
b/mlir/lib/Dialect/ArmSVE/Transforms/LegalizeVectorStorage.cpp
index d2ac850a5f70b..f16d33c004fec 100644
--- a/mlir/lib/Dialect/ArmSVE/Transforms/LegalizeVectorStorage.cpp
+++ b/mlir/lib/Dialect/ArmSVE/Transforms/LegalizeVectorStorage.cpp
@@ -298,6 +298,113 @@ struct LegalizeSVEMaskLoadConversion : public
OpRewritePattern {
}
};
+/// Transforms a `transfer_read` operation so it reads a vector of a type that
+/// can be mapped to an LLVM type. This is done by collapsing trailing
+/// dimensions so we obtain a vector type with a single scalable dimension in
+/// the rightmost position.
+///
+/// Example:
+/// ```
+/// %v = vector.transfer_read %M[%i, %j, %c0, %c0], %c0_i8
+///   {in_bounds = [false, true, true, true]}
+///   : memref<?x?x2x8xi8>, vector<2x[4]x2x8xi8>
+/// ```
+/// is rewritten to
+/// ```
+/// %collapse_shape = memref.collapse_shape %M [[0], [1, 2, 3]]
+///   : memref<?x?x2x8xi8> into memref<?x?xi8>
+/// %0 = vector.transfer_read %collapse_shape[%i, %j], %c0_i8
+///   {in_bounds = [false, true]}
+///   : memref<?x?xi8>, vector<2x[64]xi8>
+/// %1 = vector.shape_cast %0 : vector<2x[64]xi8> to vector<2x[4]x2x8xi8>
+/// ```
+struct LegalizeTransferRead : public OpRewritePattern {
+ using OpRewritePattern::OpRewritePattern;
+
+ LogicalResult matchAndRewrite(vector::TransferReadOp readOp,
+PatternRewriter &rewriter) const override {
+
+if (!readOp.getPermutationMap().isMinorIdentity())
+ return rewriter.notifyMatchFailure(readOp, "non-identity permutation");
+
+// We handle transfers of vectors with rank >= 2 and a single scalable
+// dimension.
+VectorType origVT = readOp.getVectorType();
+ArrayRef origScalableDims = origVT.getScalableDims();
+const int64_t origVRank = origVT.getRank();
+if (origVRank < 2 || llvm::count(origScalableDims, true) != 1)
+ return rewriter.notifyMatchFailure(readOp, "wrong dimensions");
+
+// Number of trailing dimensions to collapse, including the scalable
+// dimension. Nothing to do if the single scalable dimension is already
the
+// last one.
+const int64_t numCollapseDims = std::distance(
+llvm::find(origScalableDims, true), origScalableDims.end());
+if (numCollapseDims < 2)
+ return rewriter.notifyMatchFailure(readOp,
+ "scalable dimension is trailing");
+
+// We want a simple memref (not a tensor) with contiguous elements for at
+// least all the trailing dimensions up to and including the scalable one.
+auto memTy = dyn_cast(readOp.getBase().getType());
+if (!(memTy && memTy.areTrailingDimsContiguous(numCollapseDims)))
+ return rewriter.notifyMatchFailure(
+ readOp, "non-contiguous memref dimensions to collapse");
+
+// The collapsed dimensions (excluding the scalable one) of the vector and
+// the memref must match and the corresponding indices must be in-bounds
(it
+// follows these indices would be zero). This guarantees that the operation
+// transfers a contiguous block.
+if (!llvm::equal(memTy.getShape().take_back(numCollapseDims - 1),
+ origVT.getShape().take_back(numCollapseDims - 1)))
+ return rewriter.notifyMatchFailure(
+ readOp, "memref and vector dimensions do not match");
+
+SmallVector origInBounds = readOp.getInBoundsValues();
+if (!llvm::all_of(
+ArrayRef(origInBounds).take_back(numCollapseDims - 1),
+[](bool v) { return v; }))
+ return rewriter.notifyMatchFailure(readOp,
+ "out-if-bounds index to collapse");
+
+// Collapse the trailing dimensions of the memref.
+SmallVector reassoc;
+for (int64_t i = 0; i < memTy.getRank() - numCollapseDims + 1; ++i)
+ reassoc.push_back({i});
+for (int64_t i = memTy.getRank() - numCollapseDims + 1; i <
memTy.getRank();
+ ++i)
+ reassoc.back
[llvm-branch-commits] [mlir] [MLIR] Legalize certain `vector.transfer_read` ops of scalable vectors (PR #143146)
https://github.com/momchil-velikov created
https://github.com/llvm/llvm-project/pull/143146
This patch adds a transform of the `transfer_read` operation to change the vector
type to one that can be mapped to an LLVM type. This is done by collapsing
trailing dimensions so we obtain a vector type with a single scalable dimension
in the rightmost position.
>From 62ad29ddc4d8c1ffa7c5af5dbadd9bb0647964ea Mon Sep 17 00:00:00 2001
From: Momchil Velikov
Date: Wed, 14 May 2025 09:03:49 +0000
Subject: [PATCH] [MLIR] Legalize certain `vector.transfer_read` ops of
scalable vectors
This patch adds a transform of the `transfer_read` operation to change the
vector type to one that can be mapped to an LLVM type. This is done by
collapsing trailing dimensions so we obtain a vector type with a single
scalable dimension in the rightmost position.
---
.../Transforms/LegalizeVectorStorage.cpp | 110 -
.../ArmSVE/legalize-transfer-read.mlir| 226 ++
.../transfer-read-scalable-not-rightmost.mlir | 72 ++
3 files changed, 407 insertions(+), 1 deletion(-)
create mode 100644 mlir/test/Dialect/ArmSVE/legalize-transfer-read.mlir
create mode 100644
mlir/test/Integration/Dialect/Vector/CPU/ArmSVE/transfer-read-scalable-not-rightmost.mlir
diff --git a/mlir/lib/Dialect/ArmSVE/Transforms/LegalizeVectorStorage.cpp
b/mlir/lib/Dialect/ArmSVE/Transforms/LegalizeVectorStorage.cpp
index d2ac850a5f70b..f16d33c004fec 100644
--- a/mlir/lib/Dialect/ArmSVE/Transforms/LegalizeVectorStorage.cpp
+++ b/mlir/lib/Dialect/ArmSVE/Transforms/LegalizeVectorStorage.cpp
@@ -298,6 +298,113 @@ struct LegalizeSVEMaskLoadConversion : public
OpRewritePattern {
}
};
+/// Transforms a `transfer_read` operation so it reads a vector of a type that
+/// can be mapped to an LLVM type. This is done by collapsing trailing
+/// dimensions so we obtain a vector type with a single scalable dimension in
+/// the rightmost position.
+///
+/// Example:
+/// ```
+/// %v = vector.transfer_read %M[%i, %j, %c0, %c0], %c0_i8
+///   {in_bounds = [false, true, true, true]}
+///   : memref<?x?x2x8xi8>, vector<2x[4]x2x8xi8>
+/// ```
+/// is rewritten to
+/// ```
+/// %collapse_shape = memref.collapse_shape %M [[0], [1, 2, 3]]
+///   : memref<?x?x2x8xi8> into memref<?x?xi8>
+/// %0 = vector.transfer_read %collapse_shape[%i, %j], %c0_i8
+///   {in_bounds = [false, true]}
+///   : memref<?x?xi8>, vector<2x[64]xi8>
+/// %1 = vector.shape_cast %0 : vector<2x[64]xi8> to vector<2x[4]x2x8xi8>
+/// ```
+struct LegalizeTransferRead : public OpRewritePattern {
+ using OpRewritePattern::OpRewritePattern;
+
+ LogicalResult matchAndRewrite(vector::TransferReadOp readOp,
+PatternRewriter &rewriter) const override {
+
+if (!readOp.getPermutationMap().isMinorIdentity())
+ return rewriter.notifyMatchFailure(readOp, "non-identity permutation");
+
+// We handle transfers of vectors with rank >= 2 and a single scalable
+// dimension.
+VectorType origVT = readOp.getVectorType();
+ArrayRef origScalableDims = origVT.getScalableDims();
+const int64_t origVRank = origVT.getRank();
+if (origVRank < 2 || llvm::count(origScalableDims, true) != 1)
+ return rewriter.notifyMatchFailure(readOp, "wrong dimensions");
+
+// Number of trailing dimensions to collapse, including the scalable
+// dimension. Nothing to do if the single scalable dimension is already
the
+// last one.
+const int64_t numCollapseDims = std::distance(
+llvm::find(origScalableDims, true), origScalableDims.end());
+if (numCollapseDims < 2)
+ return rewriter.notifyMatchFailure(readOp,
+ "scalable dimension is trailing");
+
+// We want a simple memref (not a tensor) with contiguous elements for at
+// least all the trailing dimensions up to and including the scalable one.
+auto memTy = dyn_cast(readOp.getBase().getType());
+if (!(memTy && memTy.areTrailingDimsContiguous(numCollapseDims)))
+ return rewriter.notifyMatchFailure(
+ readOp, "non-contiguous memref dimensions to collapse");
+
+// The collapsed dimensions (excluding the scalable one) of the vector and
+// the memref must match and the corresponding indices must be in-bounds
(it
+// follows these indices would be zero). This guarantees that the operation
+// transfers a contiguous block.
+if (!llvm::equal(memTy.getShape().take_back(numCollapseDims - 1),
+ origVT.getShape().take_back(numCollapseDims - 1)))
+ return rewriter.notifyMatchFailure(
+ readOp, "memref and vector dimensions do not match");
+
+SmallVector origInBounds = readOp.getInBoundsValues();
+if (!llvm::all_of(
+ArrayRef(origInBounds).take_back(numCollapseDims - 1),
+[](bool v) { return v; }))
+ return rewriter.notifyMatchFailure(readOp,
+ "out-if-bounds index to collapse");
+
+// Collapse the trailing dimension
[llvm-branch-commits] [mlir] [MLIR] Legalize certain `vector.transfer_read` ops of scalable vectors (PR #143146)
github-actions[bot] wrote:
:warning: C/C++ code formatter, clang-format found issues in your code. :warning:
You can test this locally with the following command:
```bash
git-clang-format --diff HEAD~1 HEAD --extensions cpp -- mlir/lib/Dialect/ArmSVE/Transforms/LegalizeVectorStorage.cpp
```
View the diff from clang-format here.
```diff
diff --git a/mlir/lib/Dialect/ArmSVE/Transforms/LegalizeVectorStorage.cpp b/mlir/lib/Dialect/ArmSVE/Transforms/LegalizeVectorStorage.cpp
index f16d33c00..da36f346c 100644
--- a/mlir/lib/Dialect/ArmSVE/Transforms/LegalizeVectorStorage.cpp
+++ b/mlir/lib/Dialect/ArmSVE/Transforms/LegalizeVectorStorage.cpp
@@ -409,13 +409,13 @@ struct LegalizeTransferRead : public OpRewritePattern<vector::TransferReadOp> {
 
 void mlir::arm_sve::populateLegalizeVectorStoragePatterns(
     RewritePatternSet &patterns) {
-  patterns.add<LegalizeSVEMaskAllocation<memref::AllocaOp>,
-               LegalizeSVEMaskAllocation<memref::AllocOp>,
-               LegalizeSVEMaskTypeCastConversion,
-               LegalizeSVEMaskStoreConversion, LegalizeSVEMaskLoadConversion,
-               LegalizeTransferRead>(
-      patterns.getContext());
+  patterns
+      .add<LegalizeSVEMaskAllocation<memref::AllocaOp>,
+           LegalizeSVEMaskAllocation<memref::AllocOp>,
+           LegalizeSVEMaskTypeCastConversion, LegalizeSVEMaskStoreConversion,
+           LegalizeSVEMaskLoadConversion, LegalizeTransferRead>(
+          patterns.getContext());
 }
 
 namespace {
```
https://github.com/llvm/llvm-project/pull/143146
___
llvm-branch-commits mailing list
[email protected]
https://lists.llvm.org/cgi-bin/mailman/listinfo/llvm-branch-commits
[llvm-branch-commits] [mlir] [MLIR] Legalize certain `vector.transfer_read` ops of scalable vectors (PR #143146)
llvmbot wrote:
@llvm/pr-subscribers-mlir-sve
Author: Momchil Velikov (momchil-velikov)
Changes
This patch adds a transform of the `transfer_read` operation to change the vector
type to one that can be mapped to an LLVM type. This is done by collapsing
trailing dimensions so we obtain a vector type with a single scalable dimension
in the rightmost position.
---
Full diff: https://github.com/llvm/llvm-project/pull/143146.diff
3 Files Affected:
- (modified) mlir/lib/Dialect/ArmSVE/Transforms/LegalizeVectorStorage.cpp
(+109-1)
- (added) mlir/test/Dialect/ArmSVE/legalize-transfer-read.mlir (+226)
- (added)
mlir/test/Integration/Dialect/Vector/CPU/ArmSVE/transfer-read-scalable-not-rightmost.mlir
(+72)
```diff
diff --git a/mlir/lib/Dialect/ArmSVE/Transforms/LegalizeVectorStorage.cpp
b/mlir/lib/Dialect/ArmSVE/Transforms/LegalizeVectorStorage.cpp
index d2ac850a5f70b..f16d33c004fec 100644
--- a/mlir/lib/Dialect/ArmSVE/Transforms/LegalizeVectorStorage.cpp
+++ b/mlir/lib/Dialect/ArmSVE/Transforms/LegalizeVectorStorage.cpp
@@ -298,6 +298,113 @@ struct LegalizeSVEMaskLoadConversion : public
OpRewritePattern {
}
};
+/// Transforms a `transfer_read` operation so it reads a vector of a type that
+/// can be mapped to an LLVM type. This is done by collapsing trailing
+/// dimensions so we obtain a vector type with a single scalable dimension in
+/// the rightmost position.
+///
+/// Example:
+/// ```
+/// %v = vector.transfer_read %M[%i, %j, %c0, %c0], %c0_i8
+///   {in_bounds = [false, true, true, true]}
+///   : memref<?x?x2x8xi8>, vector<2x[4]x2x8xi8>
+/// ```
+/// is rewritten to
+/// ```
+/// %collapse_shape = memref.collapse_shape %M [[0], [1, 2, 3]]
+///   : memref<?x?x2x8xi8> into memref<?x?xi8>
+/// %0 = vector.transfer_read %collapse_shape[%i, %j], %c0_i8
+///   {in_bounds = [false, true]}
+///   : memref<?x?xi8>, vector<2x[64]xi8>
+/// %1 = vector.shape_cast %0 : vector<2x[64]xi8> to vector<2x[4]x2x8xi8>
+/// ```
+struct LegalizeTransferRead : public OpRewritePattern<vector::TransferReadOp> {
+ using OpRewritePattern::OpRewritePattern;
+
+ LogicalResult matchAndRewrite(vector::TransferReadOp readOp,
+                               PatternRewriter &rewriter) const override {
+
+if (!readOp.getPermutationMap().isMinorIdentity())
+ return rewriter.notifyMatchFailure(readOp, "non-identity permutation");
+
+// We handle transfers of vectors with rank >= 2 and a single scalable
+// dimension.
+VectorType origVT = readOp.getVectorType();
+ArrayRef<bool> origScalableDims = origVT.getScalableDims();
+const int64_t origVRank = origVT.getRank();
+if (origVRank < 2 || llvm::count(origScalableDims, true) != 1)
+ return rewriter.notifyMatchFailure(readOp, "wrong dimensions");
+
+// Number of trailing dimensions to collapse, including the scalable
+// dimension. Nothing to do if the single scalable dimension is already the
+// last one.
+const int64_t numCollapseDims = std::distance(
+    llvm::find(origScalableDims, true), origScalableDims.end());
+if (numCollapseDims < 2)
+ return rewriter.notifyMatchFailure(readOp,
+ "scalable dimension is trailing");
+
+// We want a simple memref (not a tensor) with contiguous elements for at
+// least all the trailing dimensions up to and including the scalable one.
+auto memTy = dyn_cast<MemRefType>(readOp.getBase().getType());
+if (!(memTy && memTy.areTrailingDimsContiguous(numCollapseDims)))
+ return rewriter.notifyMatchFailure(
+ readOp, "non-contiguous memref dimensions to collapse");
+
+// The collapsed dimensions (excluding the scalable one) of the vector and
+// the memref must match, and the corresponding indices must be in-bounds (it
+// follows that these indices would be zero). This guarantees that the
+// operation transfers a contiguous block.
+if (!llvm::equal(memTy.getShape().take_back(numCollapseDims - 1),
+ origVT.getShape().take_back(numCollapseDims - 1)))
+ return rewriter.notifyMatchFailure(
+ readOp, "memref and vector dimensions do not match");
+
+SmallVector<bool> origInBounds = readOp.getInBoundsValues();
+if (!llvm::all_of(
+        ArrayRef(origInBounds).take_back(numCollapseDims - 1),
+        [](bool v) { return v; }))
+  return rewriter.notifyMatchFailure(readOp,
+                                     "out-of-bounds index to collapse");
+
+// Collapse the trailing dimensions of the memref.
+SmallVector<ReassociationIndices> reassoc;
+for (int64_t i = 0; i < memTy.getRank() - numCollapseDims + 1; ++i)
+ reassoc.push_back({i});
+for (int64_t i = memTy.getRank() - numCollapseDims + 1; i < memTy.getRank();
+     ++i)
+  reassoc.back().push_back(i);
+if (!memref::CollapseShapeOp::isGuaranteedCollapsible(memTy, reassoc))
+ return failure();
+Value collapsedMem = rewriter.create<memref::CollapseShapeOp>(
+    readOp.getLoc(), readOp.getBase(), reassoc);
+
+// Get a vector type with collapsed trailing dimensions.