[Mlir-commits] [mlir] 8475fa6 - [mlir] Add a simpler lowering pattern for WhileOp representing a do-while loop

Alex Zinenko llvmlistbot at llvm.org
Wed Nov 4 00:44:36 PST 2020


Author: Alex Zinenko
Date: 2020-11-04T09:43:13+01:00
New Revision: 8475fa6ed6bb27d5abad418a7f77e9430aa825eb

URL: https://github.com/llvm/llvm-project/commit/8475fa6ed6bb27d5abad418a7f77e9430aa825eb
DIFF: https://github.com/llvm/llvm-project/commit/8475fa6ed6bb27d5abad418a7f77e9430aa825eb.diff

LOG: [mlir] Add a simpler lowering pattern for WhileOp representing a do-while loop

When the "after" region of a WhileOp is merely forwarding its arguments back to
the "before" region, i.e. WhileOp is a canonical do-while loop, a simpler CFG
subgraph that omits the "after" region with its extra branch operation can be
produced. Loop rotation from general "while" to "if { do-while }" is left for a
future canonicalization pattern when it becomes necessary.

Differential Revision: https://reviews.llvm.org/D90604

Added: 
    

Modified: 
    mlir/lib/Conversion/SCFToStandard/SCFToStandard.cpp
    mlir/test/Conversion/SCFToStandard/convert-to-cfg.mlir

Removed: 
    


################################################################################
diff  --git a/mlir/lib/Conversion/SCFToStandard/SCFToStandard.cpp b/mlir/lib/Conversion/SCFToStandard/SCFToStandard.cpp
index 953cb27eee74..425131f91a28 100644
--- a/mlir/lib/Conversion/SCFToStandard/SCFToStandard.cpp
+++ b/mlir/lib/Conversion/SCFToStandard/SCFToStandard.cpp
@@ -266,6 +266,17 @@ struct WhileLowering : public OpRewritePattern<WhileOp> {
   LogicalResult matchAndRewrite(WhileOp whileOp,
                                 PatternRewriter &rewriter) const override;
 };
+
+/// Optimized version of the above for the case where the "after" region merely
+/// forwards its arguments back to the "before" region (i.e., a "do-while"
+/// loop). This avoids inlining the "after" region at all: the "before" region
+/// branches back to its own entry instead.
+struct DoWhileLowering : public OpRewritePattern<WhileOp> {
+  using OpRewritePattern<WhileOp>::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(WhileOp whileOp,
+                                PatternRewriter &rewriter) const override;
+};
 } // namespace
 
 LogicalResult ForLowering::matchAndRewrite(ForOp forOp,
@@ -507,10 +518,60 @@ LogicalResult WhileLowering::matchAndRewrite(WhileOp whileOp,
   return success();
 }
 
+LogicalResult
+DoWhileLowering::matchAndRewrite(WhileOp whileOp,
+                                 PatternRewriter &rewriter) const {
+  if (!llvm::hasSingleElement(whileOp.after()))
+    return rewriter.notifyMatchFailure(whileOp,
+                                       "do-while simplification applicable to "
+                                       "single-block 'after' region only");
+
+  Block &afterBlock = whileOp.after().front();
+  if (!llvm::hasSingleElement(afterBlock))
+    return rewriter.notifyMatchFailure(whileOp,
+                                       "do-while simplification applicable "
+                                       "only if 'after' region has no payload");
+
+  auto yield = dyn_cast<scf::YieldOp>(&afterBlock.front());
+  if (!yield || yield.results() != afterBlock.getArguments())
+    return rewriter.notifyMatchFailure(whileOp,
+                                       "do-while simplification applicable "
+                                       "only to forwarding 'after' regions");
+
+  // Split the current block before the WhileOp to create the inlining point.
+  OpBuilder::InsertionGuard guard(rewriter);
+  Block *currentBlock = rewriter.getInsertionBlock();
+  Block *continuation =
+      rewriter.splitBlock(currentBlock, rewriter.getInsertionPoint());
+
+  // Only the "before" region should be inlined.
+  Block *before = &whileOp.before().front();
+  Block *beforeLast = &whileOp.before().back();
+  rewriter.inlineRegionBefore(whileOp.before(), continuation);
+
+  // Branch to the "before" region.
+  rewriter.setInsertionPointToEnd(currentBlock);
+  rewriter.create<BranchOp>(whileOp.getLoc(), before, whileOp.inits());
+
+  // Replace the op with values "yielded" from the "before" region, which are
+  // visible by dominance, then loop around the "before" region based on the
+  // condition. Do it in this order because rewriters such as the greedy driver
+  // erase a replaced op immediately, so `condOp` must not be used afterwards.
+  rewriter.setInsertionPointToEnd(beforeLast);
+  auto condOp = cast<ConditionOp>(beforeLast->getTerminator());
+  rewriter.replaceOp(whileOp, condOp.args());
+  rewriter.replaceOpWithNewOp<CondBranchOp>(condOp, condOp.condition(), before,
+                                            condOp.args(), continuation,
+                                            ValueRange());
+
+  return success();
+}
+
void mlir::populateLoopToStdConversionPatterns(
     OwningRewritePatternList &patterns, MLIRContext *ctx) {
   patterns.insert<ForLowering, IfLowering, ParallelLowering, WhileLowering>(
       ctx);
+  patterns.insert<DoWhileLowering>(ctx, /*benefit=*/2); // Before WhileLowering.
 }
 
 void SCFToStandardPass::runOnOperation() {

diff  --git a/mlir/test/Conversion/SCFToStandard/convert-to-cfg.mlir b/mlir/test/Conversion/SCFToStandard/convert-to-cfg.mlir
index 08ad5f1d8976..c3f1325a549b 100644
--- a/mlir/test/Conversion/SCFToStandard/convert-to-cfg.mlir
+++ b/mlir/test/Conversion/SCFToStandard/convert-to-cfg.mlir
@@ -424,6 +424,8 @@ func @minimal_while() {
     scf.condition(%0)
   } do {
   // CHECK: ^[[AFTER]]:
+  // CHECK:   "test.some_payload"() : () -> ()
+    "test.some_payload"() : () -> ()
   // CHECK:   br ^[[BEFORE]]
     scf.yield
   }
@@ -432,6 +434,29 @@ func @minimal_while() {
   return
 }

+// A WhileOp whose "after" region only forwards its arguments is lowered as a
+// do-while loop: no block is created for the "after" region.
+// CHECK-LABEL: @do_while
+func @do_while(%arg0: f32) {
+  // CHECK:   br ^[[BEFORE:.*]]({{.*}}: f32)
+  %res = scf.while (%arg1 = %arg0) : (f32) -> (f32) {
+  // CHECK: ^[[BEFORE]](%[[VAL:.*]]: f32):
+    // CHECK:   %[[COND:.*]] = "test.make_condition"() : () -> i1
+    %0 = "test.make_condition"() : () -> i1
+    // CHECK:   cond_br %[[COND]], ^[[BEFORE]](%[[VAL]] : f32), ^[[CONT:.*]]
+    scf.condition(%0) %arg1 : f32
+  } do {
+  ^bb0(%arg2: f32):
+    // CHECK-NOT: br ^[[BEFORE]]
+    scf.yield %arg2 : f32
+  }
+  // CHECK: ^[[CONT]]:
+  // CHECK:   "test.use"(%[[VAL]]) : (f32) -> ()
+  "test.use"(%res) : (f32) -> ()
+  // CHECK:   return
+  return
+}
+
 // CHECK-LABEL: @while_values
 // CHECK-SAME: (%[[ARG0:.*]]: i32, %[[ARG1:.*]]: f32)
 func @while_values(%arg0: i32, %arg1: f32) {


        


More information about the Mlir-commits mailing list