[llvm-branch-commits] [mlir] [mlir][sparse] refactoring sparse_tensor.iterate lowering pattern implementation. (PR #105566)
https://github.com/PeimingLiu created
https://github.com/llvm/llvm-project/pull/105566
[mlir][sparse] refactoring sparse_tensor.iterate lowering pattern
implementation.
>From 1a32495b27dfd003408dd5b4f12f3db7f0b73b5a Mon Sep 17 00:00:00 2001
From: Peiming Liu
Date: Thu, 15 Aug 2024 18:10:25 +0000
Subject: [PATCH] [mlir][sparse] refactoring sparse_tensor.iterate lowering
pattern implementation.
---
.../Transforms/SparseIterationToScf.cpp | 118 ++
1 file changed, 36 insertions(+), 82 deletions(-)
diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/SparseIterationToScf.cpp
b/mlir/lib/Dialect/SparseTensor/Transforms/SparseIterationToScf.cpp
index d6c0da4a9e4573..f7fcabb0220b50 100644
--- a/mlir/lib/Dialect/SparseTensor/Transforms/SparseIterationToScf.cpp
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/SparseIterationToScf.cpp
@@ -244,88 +244,41 @@ class SparseIterateOpConverter : public
OneToNOpConversionPattern<IterateOp> {
std::unique_ptr<SparseIterator> it =
iterSpace.extractIterator(rewriter, loc);
-if (it->iteratableByFor()) {
- auto [lo, hi] = it->genForCond(rewriter, loc);
- Value step = constantIndex(rewriter, loc, 1);
- SmallVector ivs;
- for (ValueRange inits : adaptor.getInitArgs())
-llvm::append_range(ivs, inits);
- scf::ForOp forOp = rewriter.create(loc, lo, hi, step, ivs);
-
- Block *loopBody = op.getBody();
- OneToNTypeMapping bodyTypeMapping(loopBody->getArgumentTypes());
- if (failed(typeConverter->convertSignatureArgs(
- loopBody->getArgumentTypes(), bodyTypeMapping)))
-return failure();
- rewriter.applySignatureConversion(loopBody, bodyTypeMapping);
-
- rewriter.eraseBlock(forOp.getBody());
- Region &dstRegion = forOp.getRegion();
- rewriter.inlineRegionBefore(op.getRegion(), dstRegion, dstRegion.end());
-
- auto yieldOp =
- llvm::cast(forOp.getBody()->getTerminator());
-
- rewriter.setInsertionPointToEnd(forOp.getBody());
- // replace sparse_tensor.yield with scf.yield.
- rewriter.create(loc, yieldOp.getResults());
- rewriter.eraseOp(yieldOp);
-
- const OneToNTypeMapping &resultMapping = adaptor.getResultMapping();
- rewriter.replaceOp(op, forOp.getResults(), resultMapping);
-} else {
- SmallVector ivs;
- // TODO: put iterator at the end of argument list to be consistent with
- // coiterate operation.
- llvm::append_range(ivs, it->getCursor());
- for (ValueRange inits : adaptor.getInitArgs())
-llvm::append_range(ivs, inits);
-
- assert(llvm::all_of(ivs, [](Value v) { return v != nullptr; }));
-
- TypeRange types = ValueRange(ivs).getTypes();
- auto whileOp = rewriter.create(loc, types, ivs);
- SmallVector l(types.size(), op.getIterator().getLoc());
-
- // Generates loop conditions.
- Block *before = rewriter.createBlock(&whileOp.getBefore(), {}, types, l);
- rewriter.setInsertionPointToStart(before);
- ValueRange bArgs = before->getArguments();
- auto [whileCond, remArgs] = it->genWhileCond(rewriter, loc, bArgs);
- assert(remArgs.size() == adaptor.getInitArgs().size());
- rewriter.create(loc, whileCond,
before->getArguments());
-
- // Generates loop body.
- Block *loopBody = op.getBody();
- OneToNTypeMapping bodyTypeMapping(loopBody->getArgumentTypes());
- if (failed(typeConverter->convertSignatureArgs(
- loopBody->getArgumentTypes(), bodyTypeMapping)))
-return failure();
- rewriter.applySignatureConversion(loopBody, bodyTypeMapping);
-
- Region &dstRegion = whileOp.getAfter();
- // TODO: handle uses of coordinate!
- rewriter.inlineRegionBefore(op.getRegion(), dstRegion, dstRegion.end());
- ValueRange aArgs = whileOp.getAfterArguments();
- auto yieldOp = llvm::cast(
- whileOp.getAfterBody()->getTerminator());
-
- rewriter.setInsertionPointToEnd(whileOp.getAfterBody());
+SmallVector ivs;
+for (ValueRange inits : adaptor.getInitArgs())
+ llvm::append_range(ivs, inits);
+
+// Type conversion on iterate op block.
+OneToNTypeMapping blockTypeMapping(op.getBody()->getArgumentTypes());
+if (failed(typeConverter->convertSignatureArgs(
+op.getBody()->getArgumentTypes(), blockTypeMapping)))
+ return rewriter.notifyMatchFailure(
+ op, "failed to convert iterate region argurment types");
+rewriter.applySignatureConversion(op.getBody(), blockTypeMapping);
+
+Block *block = op.getBody();
+ValueRange ret = genLoopWithIterator(
+rewriter, loc, it.get(), ivs, /*iterFirst=*/true,
+[block](PatternRewriter &rewriter, Location loc, Region &loopBody,
+SparseIterator *it, ValueRange reduc) -> SmallVector {
+ SmallVector blockArgs(it->getCursor());
+ // TODO: Also appends coordinates if used.
+ // blockArgs.push_back(it->deref(rewriter, loc));
+ llvm::a
[llvm-branch-commits] [mlir] [mlir][sparse] unify block arguments order between iterate/coiterate operations. (PR #105567)
https://github.com/PeimingLiu created
https://github.com/llvm/llvm-project/pull/105567
[mlir][sparse] unify block arguments order between iterate/coiterate operations.
>From 6fd099fb7039f8fda37d50f1c44cd7afd62cafb7 Mon Sep 17 00:00:00 2001
From: Peiming Liu
Date: Thu, 15 Aug 2024 21:10:37 +0000
Subject: [PATCH] [mlir][sparse] unify block arguments order between
iterate/coiterate operations.
---
.../SparseTensor/IR/SparseTensorOps.td| 7 ++--
.../SparseTensor/IR/SparseTensorDialect.cpp | 31
.../Transforms/SparseIterationToScf.cpp | 36 ++-
3 files changed, 31 insertions(+), 43 deletions(-)
diff --git a/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorOps.td
b/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorOps.td
index 20512f972e67cd..96a61419a541f7 100644
--- a/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorOps.td
+++ b/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorOps.td
@@ -1644,7 +1644,7 @@ def IterateOp : SparseTensor_Op<"iterate",
return getIterSpace().getType().getSpaceDim();
}
BlockArgument getIterator() {
- return getRegion().getArguments().front();
+ return getRegion().getArguments().back();
}
std::optional getLvlCrd(Level lvl) {
if (getCrdUsedLvls()[lvl]) {
@@ -1654,9 +1654,8 @@ def IterateOp : SparseTensor_Op<"iterate",
return std::nullopt;
}
Block::BlockArgListType getCrds() {
- // The first block argument is iterator, the remaining arguments are
- // referenced coordinates.
- return getRegion().getArguments().slice(1, getCrdUsedLvls().count());
+ // User-provided iteration arguments -> coords -> iterator.
+ return getRegion().getArguments().slice(getNumRegionIterArgs(),
getCrdUsedLvls().count());
}
unsigned getNumRegionIterArgs() {
return getRegion().getArguments().size() - 1 - getCrdUsedLvls().count();
diff --git a/mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp
b/mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp
index 16856b958d4f13..b21bc1a93036c4 100644
--- a/mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp
+++ b/mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp
@@ -2228,9 +2228,10 @@ parseSparseIterateLoop(OpAsmParser &parser,
OperationState &state,
parser.getNameLoc(),
"mismatch in number of sparse iterators and sparse spaces");
- if (failed(parseUsedCoordList(parser, state, blockArgs)))
+ SmallVector coords;
+ if (failed(parseUsedCoordList(parser, state, coords)))
return failure();
- size_t numCrds = blockArgs.size();
+ size_t numCrds = coords.size();
// Parse "iter_args(%arg = %init, ...)"
bool hasIterArgs = succeeded(parser.parseOptionalKeyword("iter_args"));
@@ -2238,6 +2239,8 @@ parseSparseIterateLoop(OpAsmParser &parser,
OperationState &state,
if (parser.parseAssignmentList(blockArgs, initArgs))
return failure();
+ blockArgs.append(coords);
+
SmallVector iterSpaceTps;
// parse ": sparse_tensor.iter_space -> ret"
if (parser.parseColon() || parser.parseTypeList(iterSpaceTps))
@@ -2267,7 +2270,7 @@ parseSparseIterateLoop(OpAsmParser &parser,
OperationState &state,
if (hasIterArgs) {
// Strip off leading args that used for coordinates.
-MutableArrayRef args = MutableArrayRef(blockArgs).drop_front(numCrds);
+MutableArrayRef args = MutableArrayRef(blockArgs).drop_back(numCrds);
if (args.size() != initArgs.size() || args.size() != state.types.size()) {
return parser.emitError(
parser.getNameLoc(),
@@ -2448,18 +2451,18 @@ void IterateOp::build(OpBuilder &builder,
OperationState &odsState,
odsState.addTypes(initArgs.getTypes());
Block *bodyBlock = builder.createBlock(bodyRegion);
- // First argument, sparse iterator
- bodyBlock->addArgument(
- llvm::cast(iterSpace.getType()).getIteratorType(),
- odsState.location);
+ // Starts with a list of user-provided loop arguments.
+ for (Value v : initArgs)
+bodyBlock->addArgument(v.getType(), v.getLoc());
- // Followed by a list of used coordinates.
+ // Follows by a list of used coordinates.
for (unsigned i = 0, e = crdUsedLvls.count(); i < e; i++)
bodyBlock->addArgument(builder.getIndexType(), odsState.location);
- // Followed by a list of user-provided loop arguments.
- for (Value v : initArgs)
-bodyBlock->addArgument(v.getType(), v.getLoc());
+ // Ends with sparse iterator
+ bodyBlock->addArgument(
+ llvm::cast(iterSpace.getType()).getIteratorType(),
+ odsState.location);
}
ParseResult IterateOp::parse(OpAsmParser &parser, OperationState &result) {
@@ -2473,9 +2476,9 @@ ParseResult IterateOp::parse(OpAsmParser &parser,
OperationState &result) {
return parser.emitError(parser.getNameLoc(),
"expected only one iterator/iteration space");
- iters.append(iterArgs);
+ iterArgs.append(iters);
Region *body = result.addRegion();
[llvm-branch-commits] [mlir] [mlir][sparse] refactoring sparse_tensor.iterate lowering pattern implementation. (PR #105566)
https://github.com/PeimingLiu updated
https://github.com/llvm/llvm-project/pull/105566
>From 937bcd814688e7c6f88ef27b7586254006e0d050 Mon Sep 17 00:00:00 2001
From: Peiming Liu
Date: Thu, 15 Aug 2024 18:10:25 +0000
Subject: [PATCH] [mlir][sparse] refactoring sparse_tensor.iterate lowering
pattern implementation.
stack-info: PR: https://github.com/llvm/llvm-project/pull/105566, branch:
users/PeimingLiu/stack/2
---
.../Transforms/SparseIterationToScf.cpp | 118 ++
1 file changed, 36 insertions(+), 82 deletions(-)
diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/SparseIterationToScf.cpp
b/mlir/lib/Dialect/SparseTensor/Transforms/SparseIterationToScf.cpp
index d6c0da4a9e4573..f7fcabb0220b50 100644
--- a/mlir/lib/Dialect/SparseTensor/Transforms/SparseIterationToScf.cpp
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/SparseIterationToScf.cpp
@@ -244,88 +244,41 @@ class SparseIterateOpConverter : public
OneToNOpConversionPattern<IterateOp> {
std::unique_ptr<SparseIterator> it =
iterSpace.extractIterator(rewriter, loc);
-if (it->iteratableByFor()) {
- auto [lo, hi] = it->genForCond(rewriter, loc);
- Value step = constantIndex(rewriter, loc, 1);
- SmallVector ivs;
- for (ValueRange inits : adaptor.getInitArgs())
-llvm::append_range(ivs, inits);
- scf::ForOp forOp = rewriter.create(loc, lo, hi, step, ivs);
-
- Block *loopBody = op.getBody();
- OneToNTypeMapping bodyTypeMapping(loopBody->getArgumentTypes());
- if (failed(typeConverter->convertSignatureArgs(
- loopBody->getArgumentTypes(), bodyTypeMapping)))
-return failure();
- rewriter.applySignatureConversion(loopBody, bodyTypeMapping);
-
- rewriter.eraseBlock(forOp.getBody());
- Region &dstRegion = forOp.getRegion();
- rewriter.inlineRegionBefore(op.getRegion(), dstRegion, dstRegion.end());
-
- auto yieldOp =
- llvm::cast(forOp.getBody()->getTerminator());
-
- rewriter.setInsertionPointToEnd(forOp.getBody());
- // replace sparse_tensor.yield with scf.yield.
- rewriter.create(loc, yieldOp.getResults());
- rewriter.eraseOp(yieldOp);
-
- const OneToNTypeMapping &resultMapping = adaptor.getResultMapping();
- rewriter.replaceOp(op, forOp.getResults(), resultMapping);
-} else {
- SmallVector ivs;
- // TODO: put iterator at the end of argument list to be consistent with
- // coiterate operation.
- llvm::append_range(ivs, it->getCursor());
- for (ValueRange inits : adaptor.getInitArgs())
-llvm::append_range(ivs, inits);
-
- assert(llvm::all_of(ivs, [](Value v) { return v != nullptr; }));
-
- TypeRange types = ValueRange(ivs).getTypes();
- auto whileOp = rewriter.create(loc, types, ivs);
- SmallVector l(types.size(), op.getIterator().getLoc());
-
- // Generates loop conditions.
- Block *before = rewriter.createBlock(&whileOp.getBefore(), {}, types, l);
- rewriter.setInsertionPointToStart(before);
- ValueRange bArgs = before->getArguments();
- auto [whileCond, remArgs] = it->genWhileCond(rewriter, loc, bArgs);
- assert(remArgs.size() == adaptor.getInitArgs().size());
- rewriter.create(loc, whileCond,
before->getArguments());
-
- // Generates loop body.
- Block *loopBody = op.getBody();
- OneToNTypeMapping bodyTypeMapping(loopBody->getArgumentTypes());
- if (failed(typeConverter->convertSignatureArgs(
- loopBody->getArgumentTypes(), bodyTypeMapping)))
-return failure();
- rewriter.applySignatureConversion(loopBody, bodyTypeMapping);
-
- Region &dstRegion = whileOp.getAfter();
- // TODO: handle uses of coordinate!
- rewriter.inlineRegionBefore(op.getRegion(), dstRegion, dstRegion.end());
- ValueRange aArgs = whileOp.getAfterArguments();
- auto yieldOp = llvm::cast(
- whileOp.getAfterBody()->getTerminator());
-
- rewriter.setInsertionPointToEnd(whileOp.getAfterBody());
+SmallVector ivs;
+for (ValueRange inits : adaptor.getInitArgs())
+ llvm::append_range(ivs, inits);
+
+// Type conversion on iterate op block.
+OneToNTypeMapping blockTypeMapping(op.getBody()->getArgumentTypes());
+if (failed(typeConverter->convertSignatureArgs(
+op.getBody()->getArgumentTypes(), blockTypeMapping)))
+ return rewriter.notifyMatchFailure(
+ op, "failed to convert iterate region argurment types");
+rewriter.applySignatureConversion(op.getBody(), blockTypeMapping);
+
+Block *block = op.getBody();
+ValueRange ret = genLoopWithIterator(
+rewriter, loc, it.get(), ivs, /*iterFirst=*/true,
+[block](PatternRewriter &rewriter, Location loc, Region &loopBody,
+SparseIterator *it, ValueRange reduc) -> SmallVector {
+ SmallVector blockArgs(it->getCursor());
+ // TODO: Also appends coordinates if used.
+ // blockArgs.push_back(it->deref(rewriter, loc));
+
[llvm-branch-commits] [mlir] [mlir][sparse] unify block arguments order between iterate/coiterate operations. (PR #105567)
https://github.com/PeimingLiu updated
https://github.com/llvm/llvm-project/pull/105567
>From 3f83d7a1eadc1101fb96707ecd348925e5aaed70 Mon Sep 17 00:00:00 2001
From: Peiming Liu
Date: Thu, 15 Aug 2024 21:10:37 +0000
Subject: [PATCH] [mlir][sparse] unify block arguments order between
iterate/coiterate operations.
stack-info: PR: https://github.com/llvm/llvm-project/pull/105567, branch:
users/PeimingLiu/stack/3
---
.../SparseTensor/IR/SparseTensorOps.td| 7 ++--
.../SparseTensor/IR/SparseTensorDialect.cpp | 31
.../Transforms/SparseIterationToScf.cpp | 36 ++-
3 files changed, 31 insertions(+), 43 deletions(-)
diff --git a/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorOps.td
b/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorOps.td
index 20512f972e67cd..96a61419a541f7 100644
--- a/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorOps.td
+++ b/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorOps.td
@@ -1644,7 +1644,7 @@ def IterateOp : SparseTensor_Op<"iterate",
return getIterSpace().getType().getSpaceDim();
}
BlockArgument getIterator() {
- return getRegion().getArguments().front();
+ return getRegion().getArguments().back();
}
std::optional getLvlCrd(Level lvl) {
if (getCrdUsedLvls()[lvl]) {
@@ -1654,9 +1654,8 @@ def IterateOp : SparseTensor_Op<"iterate",
return std::nullopt;
}
Block::BlockArgListType getCrds() {
- // The first block argument is iterator, the remaining arguments are
- // referenced coordinates.
- return getRegion().getArguments().slice(1, getCrdUsedLvls().count());
+ // User-provided iteration arguments -> coords -> iterator.
+ return getRegion().getArguments().slice(getNumRegionIterArgs(),
getCrdUsedLvls().count());
}
unsigned getNumRegionIterArgs() {
return getRegion().getArguments().size() - 1 - getCrdUsedLvls().count();
diff --git a/mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp
b/mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp
index 16856b958d4f13..b21bc1a93036c4 100644
--- a/mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp
+++ b/mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp
@@ -2228,9 +2228,10 @@ parseSparseIterateLoop(OpAsmParser &parser,
OperationState &state,
parser.getNameLoc(),
"mismatch in number of sparse iterators and sparse spaces");
- if (failed(parseUsedCoordList(parser, state, blockArgs)))
+ SmallVector coords;
+ if (failed(parseUsedCoordList(parser, state, coords)))
return failure();
- size_t numCrds = blockArgs.size();
+ size_t numCrds = coords.size();
// Parse "iter_args(%arg = %init, ...)"
bool hasIterArgs = succeeded(parser.parseOptionalKeyword("iter_args"));
@@ -2238,6 +2239,8 @@ parseSparseIterateLoop(OpAsmParser &parser,
OperationState &state,
if (parser.parseAssignmentList(blockArgs, initArgs))
return failure();
+ blockArgs.append(coords);
+
SmallVector iterSpaceTps;
// parse ": sparse_tensor.iter_space -> ret"
if (parser.parseColon() || parser.parseTypeList(iterSpaceTps))
@@ -2267,7 +2270,7 @@ parseSparseIterateLoop(OpAsmParser &parser,
OperationState &state,
if (hasIterArgs) {
// Strip off leading args that used for coordinates.
-MutableArrayRef args = MutableArrayRef(blockArgs).drop_front(numCrds);
+MutableArrayRef args = MutableArrayRef(blockArgs).drop_back(numCrds);
if (args.size() != initArgs.size() || args.size() != state.types.size()) {
return parser.emitError(
parser.getNameLoc(),
@@ -2448,18 +2451,18 @@ void IterateOp::build(OpBuilder &builder,
OperationState &odsState,
odsState.addTypes(initArgs.getTypes());
Block *bodyBlock = builder.createBlock(bodyRegion);
- // First argument, sparse iterator
- bodyBlock->addArgument(
- llvm::cast(iterSpace.getType()).getIteratorType(),
- odsState.location);
+ // Starts with a list of user-provided loop arguments.
+ for (Value v : initArgs)
+bodyBlock->addArgument(v.getType(), v.getLoc());
- // Followed by a list of used coordinates.
+ // Follows by a list of used coordinates.
for (unsigned i = 0, e = crdUsedLvls.count(); i < e; i++)
bodyBlock->addArgument(builder.getIndexType(), odsState.location);
- // Followed by a list of user-provided loop arguments.
- for (Value v : initArgs)
-bodyBlock->addArgument(v.getType(), v.getLoc());
+ // Ends with sparse iterator
+ bodyBlock->addArgument(
+ llvm::cast(iterSpace.getType()).getIteratorType(),
+ odsState.location);
}
ParseResult IterateOp::parse(OpAsmParser &parser, OperationState &result) {
@@ -2473,9 +2476,9 @@ ParseResult IterateOp::parse(OpAsmParser &parser,
OperationState &result) {
return parser.emitError(parser.getNameLoc(),
"expected only one iterator/iteration space");
- iters.append(iterArgs);
+ iterArgs.append(iters);
Region *body = r
[llvm-branch-commits] [mlir] [mlir][sparse] unify block arguments order between iterate/coiterate operations. (PR #105567)
https://github.com/PeimingLiu updated
https://github.com/llvm/llvm-project/pull/105567
>From 3f83d7a1eadc1101fb96707ecd348925e5aaed70 Mon Sep 17 00:00:00 2001
From: Peiming Liu
Date: Thu, 15 Aug 2024 21:10:37 +0000
Subject: [PATCH] [mlir][sparse] unify block arguments order between
iterate/coiterate operations.
stack-info: PR: https://github.com/llvm/llvm-project/pull/105567, branch:
users/PeimingLiu/stack/3
---
.../SparseTensor/IR/SparseTensorOps.td| 7 ++--
.../SparseTensor/IR/SparseTensorDialect.cpp | 31
.../Transforms/SparseIterationToScf.cpp | 36 ++-
3 files changed, 31 insertions(+), 43 deletions(-)
diff --git a/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorOps.td
b/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorOps.td
index 20512f972e67cd..96a61419a541f7 100644
--- a/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorOps.td
+++ b/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorOps.td
@@ -1644,7 +1644,7 @@ def IterateOp : SparseTensor_Op<"iterate",
return getIterSpace().getType().getSpaceDim();
}
BlockArgument getIterator() {
- return getRegion().getArguments().front();
+ return getRegion().getArguments().back();
}
std::optional getLvlCrd(Level lvl) {
if (getCrdUsedLvls()[lvl]) {
@@ -1654,9 +1654,8 @@ def IterateOp : SparseTensor_Op<"iterate",
return std::nullopt;
}
Block::BlockArgListType getCrds() {
- // The first block argument is iterator, the remaining arguments are
- // referenced coordinates.
- return getRegion().getArguments().slice(1, getCrdUsedLvls().count());
+ // User-provided iteration arguments -> coords -> iterator.
+ return getRegion().getArguments().slice(getNumRegionIterArgs(),
getCrdUsedLvls().count());
}
unsigned getNumRegionIterArgs() {
return getRegion().getArguments().size() - 1 - getCrdUsedLvls().count();
diff --git a/mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp
b/mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp
index 16856b958d4f13..b21bc1a93036c4 100644
--- a/mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp
+++ b/mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp
@@ -2228,9 +2228,10 @@ parseSparseIterateLoop(OpAsmParser &parser,
OperationState &state,
parser.getNameLoc(),
"mismatch in number of sparse iterators and sparse spaces");
- if (failed(parseUsedCoordList(parser, state, blockArgs)))
+ SmallVector coords;
+ if (failed(parseUsedCoordList(parser, state, coords)))
return failure();
- size_t numCrds = blockArgs.size();
+ size_t numCrds = coords.size();
// Parse "iter_args(%arg = %init, ...)"
bool hasIterArgs = succeeded(parser.parseOptionalKeyword("iter_args"));
@@ -2238,6 +2239,8 @@ parseSparseIterateLoop(OpAsmParser &parser,
OperationState &state,
if (parser.parseAssignmentList(blockArgs, initArgs))
return failure();
+ blockArgs.append(coords);
+
SmallVector iterSpaceTps;
// parse ": sparse_tensor.iter_space -> ret"
if (parser.parseColon() || parser.parseTypeList(iterSpaceTps))
@@ -2267,7 +2270,7 @@ parseSparseIterateLoop(OpAsmParser &parser,
OperationState &state,
if (hasIterArgs) {
// Strip off leading args that used for coordinates.
-MutableArrayRef args = MutableArrayRef(blockArgs).drop_front(numCrds);
+MutableArrayRef args = MutableArrayRef(blockArgs).drop_back(numCrds);
if (args.size() != initArgs.size() || args.size() != state.types.size()) {
return parser.emitError(
parser.getNameLoc(),
@@ -2448,18 +2451,18 @@ void IterateOp::build(OpBuilder &builder,
OperationState &odsState,
odsState.addTypes(initArgs.getTypes());
Block *bodyBlock = builder.createBlock(bodyRegion);
- // First argument, sparse iterator
- bodyBlock->addArgument(
- llvm::cast(iterSpace.getType()).getIteratorType(),
- odsState.location);
+ // Starts with a list of user-provided loop arguments.
+ for (Value v : initArgs)
+bodyBlock->addArgument(v.getType(), v.getLoc());
- // Followed by a list of used coordinates.
+ // Follows by a list of used coordinates.
for (unsigned i = 0, e = crdUsedLvls.count(); i < e; i++)
bodyBlock->addArgument(builder.getIndexType(), odsState.location);
- // Followed by a list of user-provided loop arguments.
- for (Value v : initArgs)
-bodyBlock->addArgument(v.getType(), v.getLoc());
+ // Ends with sparse iterator
+ bodyBlock->addArgument(
+ llvm::cast(iterSpace.getType()).getIteratorType(),
+ odsState.location);
}
ParseResult IterateOp::parse(OpAsmParser &parser, OperationState &result) {
@@ -2473,9 +2476,9 @@ ParseResult IterateOp::parse(OpAsmParser &parser,
OperationState &result) {
return parser.emitError(parser.getNameLoc(),
"expected only one iterator/iteration space");
- iters.append(iterArgs);
+ iterArgs.append(iters);
Region *body = r
[llvm-branch-commits] [mlir] [mlir][sparse] refactoring sparse_tensor.iterate lowering pattern implementation. (PR #105566)
https://github.com/PeimingLiu updated
https://github.com/llvm/llvm-project/pull/105566
>From 937bcd814688e7c6f88ef27b7586254006e0d050 Mon Sep 17 00:00:00 2001
From: Peiming Liu
Date: Thu, 15 Aug 2024 18:10:25 +0000
Subject: [PATCH] [mlir][sparse] refactoring sparse_tensor.iterate lowering
pattern implementation.
stack-info: PR: https://github.com/llvm/llvm-project/pull/105566, branch:
users/PeimingLiu/stack/2
---
.../Transforms/SparseIterationToScf.cpp | 118 ++
1 file changed, 36 insertions(+), 82 deletions(-)
diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/SparseIterationToScf.cpp
b/mlir/lib/Dialect/SparseTensor/Transforms/SparseIterationToScf.cpp
index d6c0da4a9e4573..f7fcabb0220b50 100644
--- a/mlir/lib/Dialect/SparseTensor/Transforms/SparseIterationToScf.cpp
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/SparseIterationToScf.cpp
@@ -244,88 +244,41 @@ class SparseIterateOpConverter : public
OneToNOpConversionPattern<IterateOp> {
std::unique_ptr<SparseIterator> it =
iterSpace.extractIterator(rewriter, loc);
-if (it->iteratableByFor()) {
- auto [lo, hi] = it->genForCond(rewriter, loc);
- Value step = constantIndex(rewriter, loc, 1);
- SmallVector ivs;
- for (ValueRange inits : adaptor.getInitArgs())
-llvm::append_range(ivs, inits);
- scf::ForOp forOp = rewriter.create(loc, lo, hi, step, ivs);
-
- Block *loopBody = op.getBody();
- OneToNTypeMapping bodyTypeMapping(loopBody->getArgumentTypes());
- if (failed(typeConverter->convertSignatureArgs(
- loopBody->getArgumentTypes(), bodyTypeMapping)))
-return failure();
- rewriter.applySignatureConversion(loopBody, bodyTypeMapping);
-
- rewriter.eraseBlock(forOp.getBody());
- Region &dstRegion = forOp.getRegion();
- rewriter.inlineRegionBefore(op.getRegion(), dstRegion, dstRegion.end());
-
- auto yieldOp =
- llvm::cast(forOp.getBody()->getTerminator());
-
- rewriter.setInsertionPointToEnd(forOp.getBody());
- // replace sparse_tensor.yield with scf.yield.
- rewriter.create(loc, yieldOp.getResults());
- rewriter.eraseOp(yieldOp);
-
- const OneToNTypeMapping &resultMapping = adaptor.getResultMapping();
- rewriter.replaceOp(op, forOp.getResults(), resultMapping);
-} else {
- SmallVector ivs;
- // TODO: put iterator at the end of argument list to be consistent with
- // coiterate operation.
- llvm::append_range(ivs, it->getCursor());
- for (ValueRange inits : adaptor.getInitArgs())
-llvm::append_range(ivs, inits);
-
- assert(llvm::all_of(ivs, [](Value v) { return v != nullptr; }));
-
- TypeRange types = ValueRange(ivs).getTypes();
- auto whileOp = rewriter.create(loc, types, ivs);
- SmallVector l(types.size(), op.getIterator().getLoc());
-
- // Generates loop conditions.
- Block *before = rewriter.createBlock(&whileOp.getBefore(), {}, types, l);
- rewriter.setInsertionPointToStart(before);
- ValueRange bArgs = before->getArguments();
- auto [whileCond, remArgs] = it->genWhileCond(rewriter, loc, bArgs);
- assert(remArgs.size() == adaptor.getInitArgs().size());
- rewriter.create(loc, whileCond,
before->getArguments());
-
- // Generates loop body.
- Block *loopBody = op.getBody();
- OneToNTypeMapping bodyTypeMapping(loopBody->getArgumentTypes());
- if (failed(typeConverter->convertSignatureArgs(
- loopBody->getArgumentTypes(), bodyTypeMapping)))
-return failure();
- rewriter.applySignatureConversion(loopBody, bodyTypeMapping);
-
- Region &dstRegion = whileOp.getAfter();
- // TODO: handle uses of coordinate!
- rewriter.inlineRegionBefore(op.getRegion(), dstRegion, dstRegion.end());
- ValueRange aArgs = whileOp.getAfterArguments();
- auto yieldOp = llvm::cast(
- whileOp.getAfterBody()->getTerminator());
-
- rewriter.setInsertionPointToEnd(whileOp.getAfterBody());
+SmallVector ivs;
+for (ValueRange inits : adaptor.getInitArgs())
+ llvm::append_range(ivs, inits);
+
+// Type conversion on iterate op block.
+OneToNTypeMapping blockTypeMapping(op.getBody()->getArgumentTypes());
+if (failed(typeConverter->convertSignatureArgs(
+op.getBody()->getArgumentTypes(), blockTypeMapping)))
+ return rewriter.notifyMatchFailure(
+ op, "failed to convert iterate region argurment types");
+rewriter.applySignatureConversion(op.getBody(), blockTypeMapping);
+
+Block *block = op.getBody();
+ValueRange ret = genLoopWithIterator(
+rewriter, loc, it.get(), ivs, /*iterFirst=*/true,
+[block](PatternRewriter &rewriter, Location loc, Region &loopBody,
+SparseIterator *it, ValueRange reduc) -> SmallVector {
+ SmallVector blockArgs(it->getCursor());
+ // TODO: Also appends coordinates if used.
+ // blockArgs.push_back(it->deref(rewriter, loc));
+
[llvm-branch-commits] [mlir] [mlir][sparse] refactoring sparse_tensor.iterate lowering pattern implementation. (PR #105566)
https://github.com/PeimingLiu edited https://github.com/llvm/llvm-project/pull/105566 ___ llvm-branch-commits mailing list llvm-branch-commits@lists.llvm.org https://lists.llvm.org/cgi-bin/mailman/listinfo/llvm-branch-commits
[llvm-branch-commits] [mlir] [mlir][sparse] unify block arguments order between iterate/coiterate operations. (PR #105567)
https://github.com/PeimingLiu edited https://github.com/llvm/llvm-project/pull/105567 ___ llvm-branch-commits mailing list llvm-branch-commits@lists.llvm.org https://lists.llvm.org/cgi-bin/mailman/listinfo/llvm-branch-commits
[llvm-branch-commits] [mlir] [mlir][sparse] unify block arguments order between iterate/coiterate operations. (PR #105567)
https://github.com/PeimingLiu updated
https://github.com/llvm/llvm-project/pull/105567
>From 58bae5cff0b813347512a67a89e3abf6637ad0a9 Mon Sep 17 00:00:00 2001
From: Peiming Liu
Date: Thu, 15 Aug 2024 21:10:37 +0000
Subject: [PATCH] [mlir][sparse] unify block arguments order between
iterate/coiterate operations.
stack-info: PR: https://github.com/llvm/llvm-project/pull/105567, branch:
users/PeimingLiu/stack/3
---
.../SparseTensor/IR/SparseTensorOps.td| 7 ++--
.../SparseTensor/IR/SparseTensorDialect.cpp | 31
.../Transforms/SparseIterationToScf.cpp | 36 ++-
3 files changed, 31 insertions(+), 43 deletions(-)
diff --git a/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorOps.td
b/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorOps.td
index 20512f972e67cd..96a61419a541f7 100644
--- a/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorOps.td
+++ b/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorOps.td
@@ -1644,7 +1644,7 @@ def IterateOp : SparseTensor_Op<"iterate",
return getIterSpace().getType().getSpaceDim();
}
BlockArgument getIterator() {
- return getRegion().getArguments().front();
+ return getRegion().getArguments().back();
}
std::optional getLvlCrd(Level lvl) {
if (getCrdUsedLvls()[lvl]) {
@@ -1654,9 +1654,8 @@ def IterateOp : SparseTensor_Op<"iterate",
return std::nullopt;
}
Block::BlockArgListType getCrds() {
- // The first block argument is iterator, the remaining arguments are
- // referenced coordinates.
- return getRegion().getArguments().slice(1, getCrdUsedLvls().count());
+ // User-provided iteration arguments -> coords -> iterator.
+ return getRegion().getArguments().slice(getNumRegionIterArgs(),
getCrdUsedLvls().count());
}
unsigned getNumRegionIterArgs() {
return getRegion().getArguments().size() - 1 - getCrdUsedLvls().count();
diff --git a/mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp
b/mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp
index 16856b958d4f13..b21bc1a93036c4 100644
--- a/mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp
+++ b/mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp
@@ -2228,9 +2228,10 @@ parseSparseIterateLoop(OpAsmParser &parser,
OperationState &state,
parser.getNameLoc(),
"mismatch in number of sparse iterators and sparse spaces");
- if (failed(parseUsedCoordList(parser, state, blockArgs)))
+ SmallVector coords;
+ if (failed(parseUsedCoordList(parser, state, coords)))
return failure();
- size_t numCrds = blockArgs.size();
+ size_t numCrds = coords.size();
// Parse "iter_args(%arg = %init, ...)"
bool hasIterArgs = succeeded(parser.parseOptionalKeyword("iter_args"));
@@ -2238,6 +2239,8 @@ parseSparseIterateLoop(OpAsmParser &parser,
OperationState &state,
if (parser.parseAssignmentList(blockArgs, initArgs))
return failure();
+ blockArgs.append(coords);
+
SmallVector iterSpaceTps;
// parse ": sparse_tensor.iter_space -> ret"
if (parser.parseColon() || parser.parseTypeList(iterSpaceTps))
@@ -2267,7 +2270,7 @@ parseSparseIterateLoop(OpAsmParser &parser,
OperationState &state,
if (hasIterArgs) {
// Strip off leading args that used for coordinates.
-MutableArrayRef args = MutableArrayRef(blockArgs).drop_front(numCrds);
+MutableArrayRef args = MutableArrayRef(blockArgs).drop_back(numCrds);
if (args.size() != initArgs.size() || args.size() != state.types.size()) {
return parser.emitError(
parser.getNameLoc(),
@@ -2448,18 +2451,18 @@ void IterateOp::build(OpBuilder &builder,
OperationState &odsState,
odsState.addTypes(initArgs.getTypes());
Block *bodyBlock = builder.createBlock(bodyRegion);
- // First argument, sparse iterator
- bodyBlock->addArgument(
- llvm::cast(iterSpace.getType()).getIteratorType(),
- odsState.location);
+ // Starts with a list of user-provided loop arguments.
+ for (Value v : initArgs)
+bodyBlock->addArgument(v.getType(), v.getLoc());
- // Followed by a list of used coordinates.
+ // Follows by a list of used coordinates.
for (unsigned i = 0, e = crdUsedLvls.count(); i < e; i++)
bodyBlock->addArgument(builder.getIndexType(), odsState.location);
- // Followed by a list of user-provided loop arguments.
- for (Value v : initArgs)
-bodyBlock->addArgument(v.getType(), v.getLoc());
+ // Ends with sparse iterator
+ bodyBlock->addArgument(
+ llvm::cast(iterSpace.getType()).getIteratorType(),
+ odsState.location);
}
ParseResult IterateOp::parse(OpAsmParser &parser, OperationState &result) {
@@ -2473,9 +2476,9 @@ ParseResult IterateOp::parse(OpAsmParser &parser,
OperationState &result) {
return parser.emitError(parser.getNameLoc(),
"expected only one iterator/iteration space");
- iters.append(iterArgs);
+ iterArgs.append(iters);
Region *body = r
[llvm-branch-commits] [mlir] [mlir][SparseTensor] Fix type conversion rule (PR #140350)
PeimingLiu wrote: Thx! https://github.com/llvm/llvm-project/pull/140350 ___ llvm-branch-commits mailing list llvm-branch-commits@lists.llvm.org https://lists.llvm.org/cgi-bin/mailman/listinfo/llvm-branch-commits
[llvm-branch-commits] [mlir] [mlir][SparseTensor] Fix type conversion rule (PR #140350)
https://github.com/PeimingLiu approved this pull request. https://github.com/llvm/llvm-project/pull/140350 ___ llvm-branch-commits mailing list llvm-branch-commits@lists.llvm.org https://lists.llvm.org/cgi-bin/mailman/listinfo/llvm-branch-commits
