[llvm-branch-commits] [mlir] [mlir][sparse] unify block arguments order between iterate/coiterate operations. (PR #105567)
Peiming Liu via llvm-branch-commits
llvm-branch-commits at lists.llvm.org
Wed Aug 21 11:30:14 PDT 2024
https://github.com/PeimingLiu updated https://github.com/llvm/llvm-project/pull/105567
>From 3f83d7a1eadc1101fb96707ecd348925e5aaed70 Mon Sep 17 00:00:00 2001
From: Peiming Liu <peiming at google.com>
Date: Thu, 15 Aug 2024 21:10:37 +0000
Subject: [PATCH] [mlir][sparse] unify block arguments order between
iterate/coiterate operations.
stack-info: PR: https://github.com/llvm/llvm-project/pull/105567, branch: users/PeimingLiu/stack/3
---
.../SparseTensor/IR/SparseTensorOps.td | 7 ++--
.../SparseTensor/IR/SparseTensorDialect.cpp | 31 ++++++++--------
.../Transforms/SparseIterationToScf.cpp | 36 ++++++-------------
3 files changed, 31 insertions(+), 43 deletions(-)
diff --git a/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorOps.td b/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorOps.td
index 20512f972e67cd..96a61419a541f7 100644
--- a/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorOps.td
+++ b/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorOps.td
@@ -1644,7 +1644,7 @@ def IterateOp : SparseTensor_Op<"iterate",
return getIterSpace().getType().getSpaceDim();
}
BlockArgument getIterator() {
- return getRegion().getArguments().front();
+ return getRegion().getArguments().back();
}
std::optional<BlockArgument> getLvlCrd(Level lvl) {
if (getCrdUsedLvls()[lvl]) {
@@ -1654,9 +1654,8 @@ def IterateOp : SparseTensor_Op<"iterate",
return std::nullopt;
}
Block::BlockArgListType getCrds() {
- // The first block argument is iterator, the remaining arguments are
- // referenced coordinates.
- return getRegion().getArguments().slice(1, getCrdUsedLvls().count());
+ // User-provided iteration arguments -> coords -> iterator.
+ return getRegion().getArguments().slice(getNumRegionIterArgs(), getCrdUsedLvls().count());
}
unsigned getNumRegionIterArgs() {
return getRegion().getArguments().size() - 1 - getCrdUsedLvls().count();
diff --git a/mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp b/mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp
index 16856b958d4f13..b21bc1a93036c4 100644
--- a/mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp
+++ b/mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp
@@ -2228,9 +2228,10 @@ parseSparseIterateLoop(OpAsmParser &parser, OperationState &state,
parser.getNameLoc(),
"mismatch in number of sparse iterators and sparse spaces");
- if (failed(parseUsedCoordList(parser, state, blockArgs)))
+ SmallVector<OpAsmParser::Argument> coords;
+ if (failed(parseUsedCoordList(parser, state, coords)))
return failure();
- size_t numCrds = blockArgs.size();
+ size_t numCrds = coords.size();
// Parse "iter_args(%arg = %init, ...)"
bool hasIterArgs = succeeded(parser.parseOptionalKeyword("iter_args"));
@@ -2238,6 +2239,8 @@ parseSparseIterateLoop(OpAsmParser &parser, OperationState &state,
if (parser.parseAssignmentList(blockArgs, initArgs))
return failure();
+ blockArgs.append(coords);
+
SmallVector<Type> iterSpaceTps;
// parse ": sparse_tensor.iter_space -> ret"
if (parser.parseColon() || parser.parseTypeList(iterSpaceTps))
@@ -2267,7 +2270,7 @@ parseSparseIterateLoop(OpAsmParser &parser, OperationState &state,
if (hasIterArgs) {
- // Strip off leading args that used for coordinates.
+ // Strip off trailing args that are used for coordinates.
- MutableArrayRef args = MutableArrayRef(blockArgs).drop_front(numCrds);
+ MutableArrayRef args = MutableArrayRef(blockArgs).drop_back(numCrds);
if (args.size() != initArgs.size() || args.size() != state.types.size()) {
return parser.emitError(
parser.getNameLoc(),
@@ -2448,18 +2451,18 @@ void IterateOp::build(OpBuilder &builder, OperationState &odsState,
odsState.addTypes(initArgs.getTypes());
Block *bodyBlock = builder.createBlock(bodyRegion);
- // First argument, sparse iterator
- bodyBlock->addArgument(
- llvm::cast<IterSpaceType>(iterSpace.getType()).getIteratorType(),
- odsState.location);
+ // Starts with a list of user-provided loop arguments.
+ for (Value v : initArgs)
+ bodyBlock->addArgument(v.getType(), v.getLoc());
// Followed by a list of used coordinates.
for (unsigned i = 0, e = crdUsedLvls.count(); i < e; i++)
bodyBlock->addArgument(builder.getIndexType(), odsState.location);
- // Followed by a list of user-provided loop arguments.
- for (Value v : initArgs)
- bodyBlock->addArgument(v.getType(), v.getLoc());
+ // Ends with the sparse iterator.
+ bodyBlock->addArgument(
+ llvm::cast<IterSpaceType>(iterSpace.getType()).getIteratorType(),
+ odsState.location);
}
ParseResult IterateOp::parse(OpAsmParser &parser, OperationState &result) {
@@ -2473,9 +2476,9 @@ ParseResult IterateOp::parse(OpAsmParser &parser, OperationState &result) {
return parser.emitError(parser.getNameLoc(),
"expected only one iterator/iteration space");
- iters.append(iterArgs);
+ iterArgs.append(iters);
Region *body = result.addRegion();
- if (parser.parseRegion(*body, iters))
+ if (parser.parseRegion(*body, iterArgs))
return failure();
IterateOp::ensureTerminator(*body, parser.getBuilder(), result.location);
@@ -2580,7 +2583,7 @@ MutableArrayRef<OpOperand> IterateOp::getInitsMutable() {
}
Block::BlockArgListType IterateOp::getRegionIterArgs() {
- return getRegion().getArguments().take_back(getNumRegionIterArgs());
+ return getRegion().getArguments().take_front(getNumRegionIterArgs());
}
std::optional<MutableArrayRef<OpOperand>> IterateOp::getYieldedValuesMutable() {
diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/SparseIterationToScf.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/SparseIterationToScf.cpp
index f7fcabb0220b50..71a229bea990c0 100644
--- a/mlir/lib/Dialect/SparseTensor/Transforms/SparseIterationToScf.cpp
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/SparseIterationToScf.cpp
@@ -111,7 +111,7 @@ genCoIterateBranchNest(PatternRewriter &rewriter, Location loc, CoIterateOp op,
static ValueRange genLoopWithIterator(
PatternRewriter &rewriter, Location loc, SparseIterator *it,
- ValueRange reduc, bool iterFirst,
+ ValueRange reduc,
function_ref<SmallVector<Value>(PatternRewriter &rewriter, Location loc,
Region &loopBody, SparseIterator *it,
ValueRange reduc)>
@@ -138,15 +138,9 @@ static ValueRange genLoopWithIterator(
}
return forOp.getResults();
}
- SmallVector<Value> ivs;
- // TODO: always put iterator SSA values at the end of argument list to be
- // consistent with coiterate operation.
- if (!iterFirst)
- llvm::append_range(ivs, it->getCursor());
- // Appends the user-provided values.
- llvm::append_range(ivs, reduc);
- if (iterFirst)
- llvm::append_range(ivs, it->getCursor());
+
+ SmallVector<Value> ivs(reduc);
+ llvm::append_range(ivs, it->getCursor());
TypeRange types = ValueRange(ivs).getTypes();
auto whileOp = rewriter.create<scf::WhileOp>(loc, types, ivs);
@@ -164,12 +158,8 @@ static ValueRange genLoopWithIterator(
Region &dstRegion = whileOp.getAfter();
Block *after = rewriter.createBlock(&dstRegion, {}, types, l);
ValueRange aArgs = whileOp.getAfterArguments();
- if (iterFirst) {
- aArgs = it->linkNewScope(aArgs);
- } else {
- aArgs = aArgs.take_front(reduc.size());
- it->linkNewScope(aArgs.drop_front(reduc.size()));
- }
+ it->linkNewScope(aArgs.drop_front(reduc.size()));
+ aArgs = aArgs.take_front(reduc.size());
rewriter.setInsertionPointToStart(after);
SmallVector<Value> ret = bodyBuilder(rewriter, loc, dstRegion, it, aArgs);
@@ -177,12 +167,8 @@ static ValueRange genLoopWithIterator(
// Forward loops
SmallVector<Value> yields;
- ValueRange nx = it->forward(rewriter, loc);
- if (iterFirst)
- llvm::append_range(yields, nx);
llvm::append_range(yields, ret);
- if (!iterFirst)
- llvm::append_range(yields, nx);
+ llvm::append_range(yields, it->forward(rewriter, loc));
rewriter.create<scf::YieldOp>(loc, yields);
}
return whileOp.getResults().drop_front(it->getCursor().size());
@@ -258,13 +244,13 @@ class SparseIterateOpConverter : public OneToNOpConversionPattern<IterateOp> {
Block *block = op.getBody();
ValueRange ret = genLoopWithIterator(
- rewriter, loc, it.get(), ivs, /*iterFirst=*/true,
+ rewriter, loc, it.get(), ivs,
[block](PatternRewriter &rewriter, Location loc, Region &loopBody,
SparseIterator *it, ValueRange reduc) -> SmallVector<Value> {
- SmallVector<Value> blockArgs(it->getCursor());
+ SmallVector<Value> blockArgs(reduc);
// TODO: Also appends coordinates if used.
// blockArgs.push_back(it->deref(rewriter, loc));
- llvm::append_range(blockArgs, reduc);
+ llvm::append_range(blockArgs, it->getCursor());
Block *dstBlock = &loopBody.getBlocks().front();
rewriter.inlineBlockBefore(block, dstBlock, dstBlock->end(),
@@ -404,7 +390,7 @@ class SparseCoIterateOpConverter
Block *block = &r.getBlocks().front();
ValueRange curResult = genLoopWithIterator(
- rewriter, loc, validIters.front(), userReduc, /*iterFirst=*/false,
+ rewriter, loc, validIters.front(), userReduc,
/*bodyBuilder=*/
[block](PatternRewriter &rewriter, Location loc, Region &dstRegion,
SparseIterator *it,
More information about the llvm-branch-commits
mailing list