[Mlir-commits] [mlir] 7c74a25 - [mlir][SCF][NFC] Add helper functions to get body of scf.while
Matthias Springer
llvmlistbot at llvm.org
Mon Aug 14 06:02:58 PDT 2023
Author: Matthias Springer
Date: 2023-08-14T14:57:09+02:00
New Revision: 7c74a2507c876a5cfc6c2b88c67a499cebd2bdb0
URL: https://github.com/llvm/llvm-project/commit/7c74a2507c876a5cfc6c2b88c67a499cebd2bdb0
DIFF: https://github.com/llvm/llvm-project/commit/7c74a2507c876a5cfc6c2b88c67a499cebd2bdb0.diff
LOG: [mlir][SCF][NFC] Add helper functions to get body of scf.while
Add two new helper functions, `getBeforeBody` and `getAfterBody`, for consistency with `scf.for` (`getBody`) and to make explicit in the API that each region has exactly one block. Also simplify code that assumed a region could contain more than one block.
Differential Revision: https://reviews.llvm.org/D157860
Added:
Modified:
mlir/include/mlir/Dialect/SCF/IR/SCFOps.td
mlir/lib/Conversion/SCFToControlFlow/SCFToControlFlow.cpp
mlir/lib/Dialect/SCF/IR/SCF.cpp
mlir/lib/Dialect/SCF/Transforms/BufferizableOpInterfaceImpl.cpp
mlir/lib/Dialect/SCF/Transforms/ForToWhile.cpp
Removed:
################################################################################
diff --git a/mlir/include/mlir/Dialect/SCF/IR/SCFOps.td b/mlir/include/mlir/Dialect/SCF/IR/SCFOps.td
index dd9a350d64561c..0915695bf5df34 100644
--- a/mlir/include/mlir/Dialect/SCF/IR/SCFOps.td
+++ b/mlir/include/mlir/Dialect/SCF/IR/SCFOps.td
@@ -1110,6 +1110,8 @@ def WhileOp : SCF_Op<"while",
YieldOp getYieldOp();
Block::BlockArgListType getBeforeArguments();
Block::BlockArgListType getAfterArguments();
+ Block *getBeforeBody() { return &getBefore().front(); }
+ Block *getAfterBody() { return &getAfter().front(); }
}];
let hasCanonicalizer = 1;
diff --git a/mlir/lib/Conversion/SCFToControlFlow/SCFToControlFlow.cpp b/mlir/lib/Conversion/SCFToControlFlow/SCFToControlFlow.cpp
index 57ba18bd53f1b5..51849518be6649 100644
--- a/mlir/lib/Conversion/SCFToControlFlow/SCFToControlFlow.cpp
+++ b/mlir/lib/Conversion/SCFToControlFlow/SCFToControlFlow.cpp
@@ -542,10 +542,8 @@ LogicalResult WhileLowering::matchAndRewrite(WhileOp whileOp,
rewriter.splitBlock(currentBlock, rewriter.getInsertionPoint());
// Inline both regions.
- Block *after = &whileOp.getAfter().front();
- Block *afterLast = &whileOp.getAfter().back();
- Block *before = &whileOp.getBefore().front();
- Block *beforeLast = &whileOp.getBefore().back();
+ Block *after = whileOp.getAfterBody();
+ Block *before = whileOp.getBeforeBody();
rewriter.inlineRegionBefore(whileOp.getAfter(), continuation);
rewriter.inlineRegionBefore(whileOp.getBefore(), after);
@@ -556,14 +554,14 @@ LogicalResult WhileLowering::matchAndRewrite(WhileOp whileOp,
// Replace terminators with branches. Assuming bodies are SESE, which holds
// given only the patterns from this file, we only need to look at the last
// block. This should be reconsidered if we allow break/continue in SCF.
- rewriter.setInsertionPointToEnd(beforeLast);
- auto condOp = cast<ConditionOp>(beforeLast->getTerminator());
+ rewriter.setInsertionPointToEnd(before);
+ auto condOp = cast<ConditionOp>(before->getTerminator());
rewriter.replaceOpWithNewOp<cf::CondBranchOp>(condOp, condOp.getCondition(),
after, condOp.getArgs(),
continuation, ValueRange());
- rewriter.setInsertionPointToEnd(afterLast);
- auto yieldOp = cast<scf::YieldOp>(afterLast->getTerminator());
+ rewriter.setInsertionPointToEnd(after);
+ auto yieldOp = cast<scf::YieldOp>(after->getTerminator());
rewriter.replaceOpWithNewOp<cf::BranchOp>(yieldOp, before,
yieldOp.getResults());
@@ -577,12 +575,7 @@ LogicalResult WhileLowering::matchAndRewrite(WhileOp whileOp,
LogicalResult
DoWhileLowering::matchAndRewrite(WhileOp whileOp,
PatternRewriter &rewriter) const {
- if (!llvm::hasSingleElement(whileOp.getAfter()))
- return rewriter.notifyMatchFailure(whileOp,
- "do-while simplification applicable to "
- "single-block 'after' region only");
-
- Block &afterBlock = whileOp.getAfter().front();
+ Block &afterBlock = *whileOp.getAfterBody();
if (!llvm::hasSingleElement(afterBlock))
return rewriter.notifyMatchFailure(whileOp,
"do-while simplification applicable "
@@ -601,8 +594,7 @@ DoWhileLowering::matchAndRewrite(WhileOp whileOp,
rewriter.splitBlock(currentBlock, rewriter.getInsertionPoint());
// Only the "before" region should be inlined.
- Block *before = &whileOp.getBefore().front();
- Block *beforeLast = &whileOp.getBefore().back();
+ Block *before = whileOp.getBeforeBody();
rewriter.inlineRegionBefore(whileOp.getBefore(), continuation);
// Branch to the "before" region.
@@ -610,8 +602,8 @@ DoWhileLowering::matchAndRewrite(WhileOp whileOp,
rewriter.create<cf::BranchOp>(whileOp.getLoc(), before, whileOp.getInits());
// Loop around the "before" region based on condition.
- rewriter.setInsertionPointToEnd(beforeLast);
- auto condOp = cast<ConditionOp>(beforeLast->getTerminator());
+ rewriter.setInsertionPointToEnd(before);
+ auto condOp = cast<ConditionOp>(before->getTerminator());
rewriter.replaceOpWithNewOp<cf::CondBranchOp>(condOp, condOp.getCondition(),
before, condOp.getArgs(),
continuation, ValueRange());
diff --git a/mlir/lib/Dialect/SCF/IR/SCF.cpp b/mlir/lib/Dialect/SCF/IR/SCF.cpp
index b4dae244825364..63ce3b2a469627 100644
--- a/mlir/lib/Dialect/SCF/IR/SCF.cpp
+++ b/mlir/lib/Dialect/SCF/IR/SCF.cpp
@@ -3177,19 +3177,19 @@ OperandRange WhileOp::getEntrySuccessorOperands(std::optional<unsigned> index) {
}
ConditionOp WhileOp::getConditionOp() {
- return cast<ConditionOp>(getBefore().front().getTerminator());
+ return cast<ConditionOp>(getBeforeBody()->getTerminator());
}
YieldOp WhileOp::getYieldOp() {
- return cast<YieldOp>(getAfter().front().getTerminator());
+ return cast<YieldOp>(getAfterBody()->getTerminator());
}
Block::BlockArgListType WhileOp::getBeforeArguments() {
- return getBefore().front().getArguments();
+ return getBeforeBody()->getArguments();
}
Block::BlockArgListType WhileOp::getAfterArguments() {
- return getAfter().front().getArguments();
+ return getAfterBody()->getArguments();
}
void WhileOp::getSuccessorRegions(std::optional<unsigned> index,
@@ -3260,8 +3260,7 @@ ParseResult scf::WhileOp::parse(OpAsmParser &parser, OperationState &result) {
/// Prints a `while` op.
void scf::WhileOp::print(OpAsmPrinter &p) {
- printInitializationList(p, getBefore().front().getArguments(), getInits(),
- " ");
+ printInitializationList(p, getBeforeArguments(), getInits(), " ");
p << " : ";
p.printFunctionalType(getInits().getTypes(), getResults().getTypes());
p << ' ';
@@ -3411,7 +3410,7 @@ struct RemoveLoopInvariantArgsFromBeforeBlock
LogicalResult matchAndRewrite(WhileOp op,
PatternRewriter &rewriter) const override {
- Block &afterBlock = op.getAfter().front();
+ Block &afterBlock = *op.getAfterBody();
Block::BlockArgListType beforeBlockArgs = op.getBeforeArguments();
ConditionOp condOp = op.getConditionOp();
OperandRange condOpArgs = condOp.getArgs();
@@ -3493,7 +3492,7 @@ struct RemoveLoopInvariantArgsFromBeforeBlock
&newWhile.getBefore(), /*insertPt*/ {},
ValueRange(newYieldOpArgs).getTypes(), newBeforeBlockArgLocs);
- Block &beforeBlock = op.getBefore().front();
+ Block &beforeBlock = *op.getBeforeBody();
SmallVector<Value> newBeforeBlockArgs(beforeBlock.getNumArguments());
// For each i-th before block argument we find it's replacement value as :-
// 1. If i-th before block argument is a loop invariant, we fetch it's
@@ -3563,7 +3562,7 @@ struct RemoveLoopInvariantValueYielded : public OpRewritePattern<WhileOp> {
LogicalResult matchAndRewrite(WhileOp op,
PatternRewriter &rewriter) const override {
- Block &beforeBlock = op.getBefore().front();
+ Block &beforeBlock = *op.getBeforeBody();
ConditionOp condOp = op.getConditionOp();
OperandRange condOpArgs = condOp.getArgs();
@@ -3616,7 +3615,7 @@ struct RemoveLoopInvariantValueYielded : public OpRewritePattern<WhileOp> {
*rewriter.createBlock(&newWhile.getAfter(), /*insertPt*/ {},
newAfterBlockType, newAfterBlockArgLocs);
- Block &afterBlock = op.getAfter().front();
+ Block &afterBlock = *op.getAfterBody();
// Since a new scf.condition op was created, we need to fetch the new
// `after` block arguments which will be used while replacing operations of
// previous scf.while's `after` blocks. We'd also be fetching new result
@@ -3733,7 +3732,7 @@ struct WhileUnusedResult : public OpRewritePattern<WhileOp> {
rewriter.inlineRegionBefore(op.getBefore(), newWhile.getBefore(),
newWhile.getBefore().begin());
- Block &afterBlock = op.getAfter().front();
+ Block &afterBlock = *op.getAfterBody();
rewriter.mergeBlocks(&afterBlock, &newAfterBlock, newAfterBlockArgs);
rewriter.replaceOp(op, newResults);
@@ -3774,8 +3773,7 @@ struct WhileCmpCond : public OpRewritePattern<scf::WhileOp> {
if (!cmp)
return failure();
bool changed = false;
- for (auto tup :
- llvm::zip(cond.getArgs(), op.getAfter().front().getArguments())) {
+ for (auto tup : llvm::zip(cond.getArgs(), op.getAfterArguments())) {
for (size_t opIdx = 0; opIdx < 2; opIdx++) {
if (std::get<0>(tup) != cmp.getOperand(opIdx))
continue;
@@ -3839,8 +3837,8 @@ struct WhileRemoveUnusedArgs : public OpRewritePattern<WhileOp> {
}
}
- Block &beforeBlock = op.getBefore().front();
- Block &afterBlock = op.getAfter().front();
+ Block &beforeBlock = *op.getBeforeBody();
+ Block &afterBlock = *op.getAfterBody();
beforeBlock.eraseArguments(argsToErase);
@@ -3848,8 +3846,8 @@ struct WhileRemoveUnusedArgs : public OpRewritePattern<WhileOp> {
auto newWhileOp =
rewriter.create<WhileOp>(loc, op.getResultTypes(), newInits,
/*beforeBody*/ nullptr, /*afterBody*/ nullptr);
- Block &newBeforeBlock = newWhileOp.getBefore().front();
- Block &newAfterBlock = newWhileOp.getAfter().front();
+ Block &newBeforeBlock = *newWhileOp.getBeforeBody();
+ Block &newAfterBlock = *newWhileOp.getAfterBody();
OpBuilder::InsertionGuard g(rewriter);
rewriter.setInsertionPoint(yield);
@@ -3899,8 +3897,8 @@ struct WhileRemoveDuplicatedResults : public OpRewritePattern<WhileOp> {
auto newWhileOp = rewriter.create<scf::WhileOp>(
loc, argsRange.getTypes(), op.getInits(), /*beforeBody*/ nullptr,
/*afterBody*/ nullptr);
- Block &newBeforeBlock = newWhileOp.getBefore().front();
- Block &newAfterBlock = newWhileOp.getAfter().front();
+ Block &newBeforeBlock = *newWhileOp.getBeforeBody();
+ Block &newAfterBlock = *newWhileOp.getAfterBody();
SmallVector<Value> afterArgsMapping;
SmallVector<Value> resultsMapping;
@@ -3917,8 +3915,8 @@ struct WhileRemoveDuplicatedResults : public OpRewritePattern<WhileOp> {
rewriter.replaceOpWithNewOp<ConditionOp>(condOp, condOp.getCondition(),
argsRange);
- Block &beforeBlock = op.getBefore().front();
- Block &afterBlock = op.getAfter().front();
+ Block &beforeBlock = *op.getBeforeBody();
+ Block &afterBlock = *op.getAfterBody();
rewriter.mergeBlocks(&beforeBlock, &newBeforeBlock,
newBeforeBlock.getArguments());
diff --git a/mlir/lib/Dialect/SCF/Transforms/BufferizableOpInterfaceImpl.cpp b/mlir/lib/Dialect/SCF/Transforms/BufferizableOpInterfaceImpl.cpp
index b6cb8c7f438e6f..677cd04802d3d9 100644
--- a/mlir/lib/Dialect/SCF/Transforms/BufferizableOpInterfaceImpl.cpp
+++ b/mlir/lib/Dialect/SCF/Transforms/BufferizableOpInterfaceImpl.cpp
@@ -760,13 +760,6 @@ struct WhileOpInterface
const BufferizationOptions &options) const {
auto whileOp = cast<scf::WhileOp>(op);
- assert(whileOp.getBefore().getBlocks().size() == 1 &&
- "regions with multiple blocks not supported");
- Block *beforeBody = &whileOp.getBefore().front();
- assert(whileOp.getAfter().getBlocks().size() == 1 &&
- "regions with multiple blocks not supported");
- Block *afterBody = &whileOp.getAfter().front();
-
// Indices of all bbArgs that have tensor type. These are the ones that
// are bufferized. The "before" and "after" regions may have
different args.
DenseSet<int64_t> indicesBefore = getTensorIndices(whileOp.getInits());
@@ -827,7 +820,7 @@ struct WhileOpInterface
rewriter.setInsertionPointToStart(newBeforeBody);
SmallVector<Value> newBeforeArgs = getBbArgReplacements(
rewriter, newWhileOp.getBeforeArguments(), indicesBefore);
- rewriter.mergeBlocks(beforeBody, newBeforeBody, newBeforeArgs);
+ rewriter.mergeBlocks(whileOp.getBeforeBody(), newBeforeBody, newBeforeArgs);
// Set up new iter_args and move the loop body block to the new op.
// The old block uses tensors, so wrap the (memref) bbArgs of the new block
@@ -835,7 +828,7 @@ struct WhileOpInterface
rewriter.setInsertionPointToStart(newAfterBody);
SmallVector<Value> newAfterArgs = getBbArgReplacements(
rewriter, newWhileOp.getAfterArguments(), indicesAfter);
- rewriter.mergeBlocks(afterBody, newAfterBody, newAfterArgs);
+ rewriter.mergeBlocks(whileOp.getAfterBody(), newAfterBody, newAfterArgs);
// Replace loop results.
replaceOpWithBufferizedValues(rewriter, op, newWhileOp->getResults());
diff --git a/mlir/lib/Dialect/SCF/Transforms/ForToWhile.cpp b/mlir/lib/Dialect/SCF/Transforms/ForToWhile.cpp
index 8863b0833d3e72..7b6b07eabf6c48 100644
--- a/mlir/lib/Dialect/SCF/Transforms/ForToWhile.cpp
+++ b/mlir/lib/Dialect/SCF/Transforms/ForToWhile.cpp
@@ -57,7 +57,7 @@ struct ForLoopLoweringPattern : public OpRewritePattern<ForOp> {
// arguments to the 'after' region.
auto *beforeBlock = rewriter.createBlock(
&whileOp.getBefore(), whileOp.getBefore().begin(), lcvTypes, lcvLocs);
- rewriter.setInsertionPointToStart(&whileOp.getBefore().front());
+ rewriter.setInsertionPointToStart(whileOp.getBeforeBody());
auto cmpOp = rewriter.create<arith::CmpIOp>(
whileOp.getLoc(), arith::CmpIPredicate::slt,
beforeBlock->getArgument(0), forOp.getUpperBound());
More information about the Mlir-commits
mailing list