[Mlir-commits] [mlir] 7c74a25 - [mlir][SCF][NFC] Add helper functions to get body of scf.while

Matthias Springer llvmlistbot at llvm.org
Mon Aug 14 06:02:58 PDT 2023


Author: Matthias Springer
Date: 2023-08-14T14:57:09+02:00
New Revision: 7c74a2507c876a5cfc6c2b88c67a499cebd2bdb0

URL: https://github.com/llvm/llvm-project/commit/7c74a2507c876a5cfc6c2b88c67a499cebd2bdb0
DIFF: https://github.com/llvm/llvm-project/commit/7c74a2507c876a5cfc6c2b88c67a499cebd2bdb0.diff

LOG: [mlir][SCF][NFC] Add helper functions to get body of scf.while

Add two new helper functions `getBeforeBody` and `getAfterBody` to be consistent with "scf.for" (`getBody`) and to show in the API that both regions have exactly one block. Also simplify some code that assumed that there can be more than one block in a region.

Differential Revision: https://reviews.llvm.org/D157860

Added: 
    

Modified: 
    mlir/include/mlir/Dialect/SCF/IR/SCFOps.td
    mlir/lib/Conversion/SCFToControlFlow/SCFToControlFlow.cpp
    mlir/lib/Dialect/SCF/IR/SCF.cpp
    mlir/lib/Dialect/SCF/Transforms/BufferizableOpInterfaceImpl.cpp
    mlir/lib/Dialect/SCF/Transforms/ForToWhile.cpp

Removed: 
    


################################################################################
diff --git a/mlir/include/mlir/Dialect/SCF/IR/SCFOps.td b/mlir/include/mlir/Dialect/SCF/IR/SCFOps.td
index dd9a350d64561c..0915695bf5df34 100644
--- a/mlir/include/mlir/Dialect/SCF/IR/SCFOps.td
+++ b/mlir/include/mlir/Dialect/SCF/IR/SCFOps.td
@@ -1110,6 +1110,8 @@ def WhileOp : SCF_Op<"while",
     YieldOp getYieldOp();
     Block::BlockArgListType getBeforeArguments();
     Block::BlockArgListType getAfterArguments();
+    Block *getBeforeBody() { return &getBefore().front(); }
+    Block *getAfterBody() { return &getAfter().front(); }
   }];
 
   let hasCanonicalizer = 1;

diff --git a/mlir/lib/Conversion/SCFToControlFlow/SCFToControlFlow.cpp b/mlir/lib/Conversion/SCFToControlFlow/SCFToControlFlow.cpp
index 57ba18bd53f1b5..51849518be6649 100644
--- a/mlir/lib/Conversion/SCFToControlFlow/SCFToControlFlow.cpp
+++ b/mlir/lib/Conversion/SCFToControlFlow/SCFToControlFlow.cpp
@@ -542,10 +542,8 @@ LogicalResult WhileLowering::matchAndRewrite(WhileOp whileOp,
       rewriter.splitBlock(currentBlock, rewriter.getInsertionPoint());
 
   // Inline both regions.
-  Block *after = &whileOp.getAfter().front();
-  Block *afterLast = &whileOp.getAfter().back();
-  Block *before = &whileOp.getBefore().front();
-  Block *beforeLast = &whileOp.getBefore().back();
+  Block *after = whileOp.getAfterBody();
+  Block *before = whileOp.getBeforeBody();
   rewriter.inlineRegionBefore(whileOp.getAfter(), continuation);
   rewriter.inlineRegionBefore(whileOp.getBefore(), after);
 
@@ -556,14 +554,14 @@ LogicalResult WhileLowering::matchAndRewrite(WhileOp whileOp,
   // Replace terminators with branches. Assuming bodies are SESE, which holds
   // given only the patterns from this file, we only need to look at the last
   // block. This should be reconsidered if we allow break/continue in SCF.
-  rewriter.setInsertionPointToEnd(beforeLast);
-  auto condOp = cast<ConditionOp>(beforeLast->getTerminator());
+  rewriter.setInsertionPointToEnd(before);
+  auto condOp = cast<ConditionOp>(before->getTerminator());
   rewriter.replaceOpWithNewOp<cf::CondBranchOp>(condOp, condOp.getCondition(),
                                                 after, condOp.getArgs(),
                                                 continuation, ValueRange());
 
-  rewriter.setInsertionPointToEnd(afterLast);
-  auto yieldOp = cast<scf::YieldOp>(afterLast->getTerminator());
+  rewriter.setInsertionPointToEnd(after);
+  auto yieldOp = cast<scf::YieldOp>(after->getTerminator());
   rewriter.replaceOpWithNewOp<cf::BranchOp>(yieldOp, before,
                                             yieldOp.getResults());
 
@@ -577,12 +575,7 @@ LogicalResult WhileLowering::matchAndRewrite(WhileOp whileOp,
 LogicalResult
 DoWhileLowering::matchAndRewrite(WhileOp whileOp,
                                  PatternRewriter &rewriter) const {
-  if (!llvm::hasSingleElement(whileOp.getAfter()))
-    return rewriter.notifyMatchFailure(whileOp,
-                                       "do-while simplification applicable to "
-                                       "single-block 'after' region only");
-
-  Block &afterBlock = whileOp.getAfter().front();
+  Block &afterBlock = *whileOp.getAfterBody();
   if (!llvm::hasSingleElement(afterBlock))
     return rewriter.notifyMatchFailure(whileOp,
                                        "do-while simplification applicable "
@@ -601,8 +594,7 @@ DoWhileLowering::matchAndRewrite(WhileOp whileOp,
       rewriter.splitBlock(currentBlock, rewriter.getInsertionPoint());
 
   // Only the "before" region should be inlined.
-  Block *before = &whileOp.getBefore().front();
-  Block *beforeLast = &whileOp.getBefore().back();
+  Block *before = whileOp.getBeforeBody();
   rewriter.inlineRegionBefore(whileOp.getBefore(), continuation);
 
   // Branch to the "before" region.
@@ -610,8 +602,8 @@ DoWhileLowering::matchAndRewrite(WhileOp whileOp,
   rewriter.create<cf::BranchOp>(whileOp.getLoc(), before, whileOp.getInits());
 
   // Loop around the "before" region based on condition.
-  rewriter.setInsertionPointToEnd(beforeLast);
-  auto condOp = cast<ConditionOp>(beforeLast->getTerminator());
+  rewriter.setInsertionPointToEnd(before);
+  auto condOp = cast<ConditionOp>(before->getTerminator());
   rewriter.replaceOpWithNewOp<cf::CondBranchOp>(condOp, condOp.getCondition(),
                                                 before, condOp.getArgs(),
                                                 continuation, ValueRange());

diff --git a/mlir/lib/Dialect/SCF/IR/SCF.cpp b/mlir/lib/Dialect/SCF/IR/SCF.cpp
index b4dae244825364..63ce3b2a469627 100644
--- a/mlir/lib/Dialect/SCF/IR/SCF.cpp
+++ b/mlir/lib/Dialect/SCF/IR/SCF.cpp
@@ -3177,19 +3177,19 @@ OperandRange WhileOp::getEntrySuccessorOperands(std::optional<unsigned> index) {
 }
 
 ConditionOp WhileOp::getConditionOp() {
-  return cast<ConditionOp>(getBefore().front().getTerminator());
+  return cast<ConditionOp>(getBeforeBody()->getTerminator());
 }
 
 YieldOp WhileOp::getYieldOp() {
-  return cast<YieldOp>(getAfter().front().getTerminator());
+  return cast<YieldOp>(getAfterBody()->getTerminator());
 }
 
 Block::BlockArgListType WhileOp::getBeforeArguments() {
-  return getBefore().front().getArguments();
+  return getBeforeBody()->getArguments();
 }
 
 Block::BlockArgListType WhileOp::getAfterArguments() {
-  return getAfter().front().getArguments();
+  return getAfterBody()->getArguments();
 }
 
 void WhileOp::getSuccessorRegions(std::optional<unsigned> index,
@@ -3260,8 +3260,7 @@ ParseResult scf::WhileOp::parse(OpAsmParser &parser, OperationState &result) {
 
 /// Prints a `while` op.
 void scf::WhileOp::print(OpAsmPrinter &p) {
-  printInitializationList(p, getBefore().front().getArguments(), getInits(),
-                          " ");
+  printInitializationList(p, getBeforeArguments(), getInits(), " ");
   p << " : ";
   p.printFunctionalType(getInits().getTypes(), getResults().getTypes());
   p << ' ';
@@ -3411,7 +3410,7 @@ struct RemoveLoopInvariantArgsFromBeforeBlock
 
   LogicalResult matchAndRewrite(WhileOp op,
                                 PatternRewriter &rewriter) const override {
-    Block &afterBlock = op.getAfter().front();
+    Block &afterBlock = *op.getAfterBody();
     Block::BlockArgListType beforeBlockArgs = op.getBeforeArguments();
     ConditionOp condOp = op.getConditionOp();
     OperandRange condOpArgs = condOp.getArgs();
@@ -3493,7 +3492,7 @@ struct RemoveLoopInvariantArgsFromBeforeBlock
         &newWhile.getBefore(), /*insertPt*/ {},
         ValueRange(newYieldOpArgs).getTypes(), newBeforeBlockArgLocs);
 
-    Block &beforeBlock = op.getBefore().front();
+    Block &beforeBlock = *op.getBeforeBody();
     SmallVector<Value> newBeforeBlockArgs(beforeBlock.getNumArguments());
     // For each i-th before block argument we find it's replacement value as :-
     //   1. If i-th before block argument is a loop invariant, we fetch it's
@@ -3563,7 +3562,7 @@ struct RemoveLoopInvariantValueYielded : public OpRewritePattern<WhileOp> {
 
   LogicalResult matchAndRewrite(WhileOp op,
                                 PatternRewriter &rewriter) const override {
-    Block &beforeBlock = op.getBefore().front();
+    Block &beforeBlock = *op.getBeforeBody();
     ConditionOp condOp = op.getConditionOp();
     OperandRange condOpArgs = condOp.getArgs();
 
@@ -3616,7 +3615,7 @@ struct RemoveLoopInvariantValueYielded : public OpRewritePattern<WhileOp> {
         *rewriter.createBlock(&newWhile.getAfter(), /*insertPt*/ {},
                               newAfterBlockType, newAfterBlockArgLocs);
 
-    Block &afterBlock = op.getAfter().front();
+    Block &afterBlock = *op.getAfterBody();
     // Since a new scf.condition op was created, we need to fetch the new
     // `after` block arguments which will be used while replacing operations of
     // previous scf.while's `after` blocks. We'd also be fetching new result
@@ -3733,7 +3732,7 @@ struct WhileUnusedResult : public OpRewritePattern<WhileOp> {
     rewriter.inlineRegionBefore(op.getBefore(), newWhile.getBefore(),
                                 newWhile.getBefore().begin());
 
-    Block &afterBlock = op.getAfter().front();
+    Block &afterBlock = *op.getAfterBody();
     rewriter.mergeBlocks(&afterBlock, &newAfterBlock, newAfterBlockArgs);
 
     rewriter.replaceOp(op, newResults);
@@ -3774,8 +3773,7 @@ struct WhileCmpCond : public OpRewritePattern<scf::WhileOp> {
     if (!cmp)
       return failure();
     bool changed = false;
-    for (auto tup :
-         llvm::zip(cond.getArgs(), op.getAfter().front().getArguments())) {
+    for (auto tup : llvm::zip(cond.getArgs(), op.getAfterArguments())) {
       for (size_t opIdx = 0; opIdx < 2; opIdx++) {
         if (std::get<0>(tup) != cmp.getOperand(opIdx))
           continue;
@@ -3839,8 +3837,8 @@ struct WhileRemoveUnusedArgs : public OpRewritePattern<WhileOp> {
       }
     }
 
-    Block &beforeBlock = op.getBefore().front();
-    Block &afterBlock = op.getAfter().front();
+    Block &beforeBlock = *op.getBeforeBody();
+    Block &afterBlock = *op.getAfterBody();
 
     beforeBlock.eraseArguments(argsToErase);
 
@@ -3848,8 +3846,8 @@ struct WhileRemoveUnusedArgs : public OpRewritePattern<WhileOp> {
     auto newWhileOp =
         rewriter.create<WhileOp>(loc, op.getResultTypes(), newInits,
                                  /*beforeBody*/ nullptr, /*afterBody*/ nullptr);
-    Block &newBeforeBlock = newWhileOp.getBefore().front();
-    Block &newAfterBlock = newWhileOp.getAfter().front();
+    Block &newBeforeBlock = *newWhileOp.getBeforeBody();
+    Block &newAfterBlock = *newWhileOp.getAfterBody();
 
     OpBuilder::InsertionGuard g(rewriter);
     rewriter.setInsertionPoint(yield);
@@ -3899,8 +3897,8 @@ struct WhileRemoveDuplicatedResults : public OpRewritePattern<WhileOp> {
     auto newWhileOp = rewriter.create<scf::WhileOp>(
         loc, argsRange.getTypes(), op.getInits(), /*beforeBody*/ nullptr,
         /*afterBody*/ nullptr);
-    Block &newBeforeBlock = newWhileOp.getBefore().front();
-    Block &newAfterBlock = newWhileOp.getAfter().front();
+    Block &newBeforeBlock = *newWhileOp.getBeforeBody();
+    Block &newAfterBlock = *newWhileOp.getAfterBody();
 
     SmallVector<Value> afterArgsMapping;
     SmallVector<Value> resultsMapping;
@@ -3917,8 +3915,8 @@ struct WhileRemoveDuplicatedResults : public OpRewritePattern<WhileOp> {
     rewriter.replaceOpWithNewOp<ConditionOp>(condOp, condOp.getCondition(),
                                              argsRange);
 
-    Block &beforeBlock = op.getBefore().front();
-    Block &afterBlock = op.getAfter().front();
+    Block &beforeBlock = *op.getBeforeBody();
+    Block &afterBlock = *op.getAfterBody();
 
     rewriter.mergeBlocks(&beforeBlock, &newBeforeBlock,
                          newBeforeBlock.getArguments());

diff --git a/mlir/lib/Dialect/SCF/Transforms/BufferizableOpInterfaceImpl.cpp b/mlir/lib/Dialect/SCF/Transforms/BufferizableOpInterfaceImpl.cpp
index b6cb8c7f438e6f..677cd04802d3d9 100644
--- a/mlir/lib/Dialect/SCF/Transforms/BufferizableOpInterfaceImpl.cpp
+++ b/mlir/lib/Dialect/SCF/Transforms/BufferizableOpInterfaceImpl.cpp
@@ -760,13 +760,6 @@ struct WhileOpInterface
                           const BufferizationOptions &options) const {
     auto whileOp = cast<scf::WhileOp>(op);
 
-    assert(whileOp.getBefore().getBlocks().size() == 1 &&
-           "regions with multiple blocks not supported");
-    Block *beforeBody = &whileOp.getBefore().front();
-    assert(whileOp.getAfter().getBlocks().size() == 1 &&
-           "regions with multiple blocks not supported");
-    Block *afterBody = &whileOp.getAfter().front();
-
     // Indices of all bbArgs that have tensor type. These are the ones that
     // are bufferized. The "before" and "after" regions may have different args.
     DenseSet<int64_t> indicesBefore = getTensorIndices(whileOp.getInits());
@@ -827,7 +820,7 @@ struct WhileOpInterface
     rewriter.setInsertionPointToStart(newBeforeBody);
     SmallVector<Value> newBeforeArgs = getBbArgReplacements(
         rewriter, newWhileOp.getBeforeArguments(), indicesBefore);
-    rewriter.mergeBlocks(beforeBody, newBeforeBody, newBeforeArgs);
+    rewriter.mergeBlocks(whileOp.getBeforeBody(), newBeforeBody, newBeforeArgs);
 
     // Set up new iter_args and move the loop body block to the new op.
     // The old block uses tensors, so wrap the (memref) bbArgs of the new block
@@ -835,7 +828,7 @@ struct WhileOpInterface
     rewriter.setInsertionPointToStart(newAfterBody);
     SmallVector<Value> newAfterArgs = getBbArgReplacements(
         rewriter, newWhileOp.getAfterArguments(), indicesAfter);
-    rewriter.mergeBlocks(afterBody, newAfterBody, newAfterArgs);
+    rewriter.mergeBlocks(whileOp.getAfterBody(), newAfterBody, newAfterArgs);
 
     // Replace loop results.
     replaceOpWithBufferizedValues(rewriter, op, newWhileOp->getResults());

diff --git a/mlir/lib/Dialect/SCF/Transforms/ForToWhile.cpp b/mlir/lib/Dialect/SCF/Transforms/ForToWhile.cpp
index 8863b0833d3e72..7b6b07eabf6c48 100644
--- a/mlir/lib/Dialect/SCF/Transforms/ForToWhile.cpp
+++ b/mlir/lib/Dialect/SCF/Transforms/ForToWhile.cpp
@@ -57,7 +57,7 @@ struct ForLoopLoweringPattern : public OpRewritePattern<ForOp> {
     // arguments to the 'after' region.
     auto *beforeBlock = rewriter.createBlock(
         &whileOp.getBefore(), whileOp.getBefore().begin(), lcvTypes, lcvLocs);
-    rewriter.setInsertionPointToStart(&whileOp.getBefore().front());
+    rewriter.setInsertionPointToStart(whileOp.getBeforeBody());
     auto cmpOp = rewriter.create<arith::CmpIOp>(
         whileOp.getLoc(), arith::CmpIPredicate::slt,
         beforeBlock->getArgument(0), forOp.getUpperBound());


        


More information about the Mlir-commits mailing list