[Mlir-commits] [mlir] 4c0e255 - [mlir] Add lowering to CFG for WhileOp

Alex Zinenko llvmlistbot at llvm.org
Wed Nov 4 00:44:34 PST 2020


Author: Alex Zinenko
Date: 2020-11-04T09:43:13+01:00
New Revision: 4c0e255c98cc0e7769be9c9b2700d96e76aec99f

URL: https://github.com/llvm/llvm-project/commit/4c0e255c98cc0e7769be9c9b2700d96e76aec99f
DIFF: https://github.com/llvm/llvm-project/commit/4c0e255c98cc0e7769be9c9b2700d96e76aec99f.diff

LOG: [mlir] Add lowering to CFG for WhileOp

The lowering is a straightforward inlining of the "before" and "after" regions
connected by (conditional) branches. This plugs the WhileOp into the
progressive lowering scheme. Future commits may choose to target WhileOp
instead of CFG when lowering ForOp.

Differential Revision: https://reviews.llvm.org/D90603

Added: 
    

Modified: 
    mlir/lib/Conversion/SCFToStandard/SCFToStandard.cpp
    mlir/test/Conversion/SCFToStandard/convert-to-cfg.mlir

Removed: 
    


################################################################################
diff  --git a/mlir/lib/Conversion/SCFToStandard/SCFToStandard.cpp b/mlir/lib/Conversion/SCFToStandard/SCFToStandard.cpp
index 625077a28aac..953cb27eee74 100644
--- a/mlir/lib/Conversion/SCFToStandard/SCFToStandard.cpp
+++ b/mlir/lib/Conversion/SCFToStandard/SCFToStandard.cpp
@@ -200,6 +200,72 @@ struct ParallelLowering : public OpRewritePattern<mlir::scf::ParallelOp> {
   LogicalResult matchAndRewrite(mlir::scf::ParallelOp parallelOp,
                                 PatternRewriter &rewriter) const override;
 };
+
+/// Lowers scf.while into a CFG subgraph. The loop regions are not required to
+/// consist of a single block (e.g., nested SCF constructs inside them may have
+/// already been converted to CFG), but each region must be single-exit, with
+/// the exit in its last block. Operations that follow the original WhileOp are
+/// moved into a new continuation block. Both regions of the WhileOp are
+/// inlined, and their terminators are rewritten into branches so that the
+/// control flow implements the loop as follows.
+///
+///      +---------------------------------+
+///      |   <code before the WhileOp>     |
+///      |   br ^before(%operands...)      |
+///      +---------------------------------+
+///             |
+///  -------|   |
+///  |      v   v
+///  |   +--------------------------------+
+///  |   | ^before(%bargs...):            |
+///  |   |   %vals... = <some payload>    |
+///  |   +--------------------------------+
+///  |                   |
+///  |                  ...
+///  |                   |
+///  |   +--------------------------------+
+///  |   | ^before-last:                  |
+///  |   |   %cond = <compute condition>  |
+///  |   |   cond_br %cond,               |
+///  |   |        ^after(%vals...), ^cont |
+///  |   +--------------------------------+
+///  |          |               |
+///  |          |               -------------|
+///  |          v                            |
+///  |   +--------------------------------+  |
+///  |   | ^after(%aargs...):             |  |
+///  |   |   <body contents>              |  |
+///  |   +--------------------------------+  |
+///  |                   |                   |
+///  |                  ...                  |
+///  |                   |                   |
+///  |   +--------------------------------+  |
+///  |   | ^after-last:                   |  |
+///  |   |   %yields... = <some payload>  |  |
+///  |   |   br ^before(%yields...)       |  |
+///  |   +--------------------------------+  |
+///  |          |                            |
+///  |-----------        |--------------------
+///                      v
+///      +--------------------------------+
+///      | ^cont:                         |
+///      |   <code after the WhileOp>     |
+///      |   <%vals from 'before' region  |
+///      |          visible by dominance> |
+///      +--------------------------------+
+///
+/// Values flow between ex-regions (the groups of blocks that formed a region
+/// before inlining) through the block arguments of their entry blocks, which
+/// are visible in every block they dominate. Likewise, the results of the
+/// WhileOp are defined in the 'before' region, which must have a single
+/// exiting block, and are therefore accessible in the continuation block by
+/// dominance.
+struct WhileLowering : OpRewritePattern<WhileOp> {
+  using OpRewritePattern::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(WhileOp whileOp,
+                                PatternRewriter &rewriter) const override;
+};
 } // namespace
 
 LogicalResult ForLowering::matchAndRewrite(ForOp forOp,
@@ -399,18 +465,61 @@ ParallelLowering::matchAndRewrite(ParallelOp parallelOp,
   return success();
 }
 
+LogicalResult WhileLowering::matchAndRewrite(WhileOp whileOp,
+                                             PatternRewriter &rewriter) const {
+  OpBuilder::InsertionGuard guard(rewriter);
+  Location loc = whileOp.getLoc();
+
+  // Split the current block before the WhileOp to create the inlining point.
+  Block *currentBlock = rewriter.getInsertionBlock();
+  Block *continuation =
+      rewriter.splitBlock(currentBlock, rewriter.getInsertionPoint());
+
+  // Inline both regions.
+  Block *after = &whileOp.after().front();
+  Block *afterLast = &whileOp.after().back();
+  Block *before = &whileOp.before().front();
+  Block *beforeLast = &whileOp.before().back();
+  rewriter.inlineRegionBefore(whileOp.after(), continuation);
+  rewriter.inlineRegionBefore(whileOp.before(), after);
+
+  // Branch to the "before" region.
+  rewriter.setInsertionPointToEnd(currentBlock);
+  rewriter.create<BranchOp>(loc, before, whileOp.inits());
+
+  // Replace terminators with branches. Assuming bodies are SESE, which holds
+  // given only the patterns from this file, we only need to look at the last
+  // block. This should be reconsidered if we allow break/continue in SCF.
+  rewriter.setInsertionPointToEnd(beforeLast);
+  auto condOp = cast<ConditionOp>(beforeLast->getTerminator());
+  // `condOp` is erased by the replacement below; copy its operands first.
+  SmallVector<Value, 4> loopResults = llvm::to_vector<4>(condOp.args());
+  rewriter.replaceOpWithNewOp<CondBranchOp>(condOp, condOp.condition(), after,
+                                            loopResults, continuation,
+                                            ValueRange());
+  rewriter.setInsertionPointToEnd(afterLast);
+  auto yieldOp = cast<scf::YieldOp>(afterLast->getTerminator());
+  rewriter.replaceOpWithNewOp<BranchOp>(yieldOp, before, yieldOp.results());
+
+  // Replace the op with values "yielded" from the "before" region, which are
+  // visible by dominance. `condOp` was replaced, so use the saved copy.
+  rewriter.replaceOp(whileOp, loopResults);
+  return success();
+}
+
 void mlir::populateLoopToStdConversionPatterns(
     OwningRewritePatternList &patterns, MLIRContext *ctx) {
-  patterns.insert<ForLowering, IfLowering, ParallelLowering>(ctx);
+  patterns.insert<ForLowering, IfLowering, ParallelLowering,
+                  WhileLowering>(ctx);
 }
 
 void SCFToStandardPass::runOnOperation() {
   OwningRewritePatternList patterns;
   populateLoopToStdConversionPatterns(patterns, &getContext());
-  // Configure conversion to lower out scf.for, scf.if and scf.parallel.
-  // Anything else is fine.
+  // Configure conversion to lower out scf.for, scf.if, scf.parallel and
+  // scf.while. Anything else is fine.
   ConversionTarget target(getContext());
-  target.addIllegalOp<scf::ForOp, scf::IfOp, scf::ParallelOp>();
+  target.addIllegalOp<scf::ForOp, scf::IfOp, scf::ParallelOp, scf::WhileOp>();
   target.markUnknownOpDynamicallyLegal([](Operation *) { return true; });
   if (failed(
           applyPartialConversion(getOperation(), target, std::move(patterns))))

diff  --git a/mlir/test/Conversion/SCFToStandard/convert-to-cfg.mlir b/mlir/test/Conversion/SCFToStandard/convert-to-cfg.mlir
index a6d22d982f18..08ad5f1d8976 100644
--- a/mlir/test/Conversion/SCFToStandard/convert-to-cfg.mlir
+++ b/mlir/test/Conversion/SCFToStandard/convert-to-cfg.mlir
@@ -412,3 +412,116 @@ func @unknown_op_inside_loop(%arg0: index, %arg1: index, %arg2: index) {
   }
   return
 }
+
+// CHECK-LABEL: @minimal_while
+func @minimal_while() {
+  // CHECK:   %[[COND:.*]] = "test.make_condition"() : () -> i1
+  // CHECK:   br ^[[BEFORE:.*]]
+  %cond = "test.make_condition"() : () -> i1
+  scf.while : () -> () {
+    // CHECK: ^[[BEFORE]]:
+    // CHECK:   cond_br %[[COND]], ^[[AFTER:.*]], ^[[CONT:.*]]
+    scf.condition(%cond)
+  } do {
+    // CHECK: ^[[AFTER]]:
+    // CHECK:   br ^[[BEFORE]]
+    scf.yield
+  }
+  // CHECK: ^[[CONT]]:
+  // CHECK:   return
+  return
+}
+
+// CHECK-LABEL: @while_values
+// CHECK-SAME: (%[[ARG0:.*]]: i32, %[[ARG1:.*]]: f32)
+func @while_values(%arg0: i32, %arg1: f32) {
+  // CHECK:     %[[COND:.*]] = "test.make_condition"() : () -> i1
+  %0 = "test.make_condition"() : () -> i1
+  %c0_i32 = constant 0 : i32
+  %cst = constant 0.000000e+00 : f32
+  // CHECK:     br ^[[BEFORE:.*]](%[[ARG0]], %[[ARG1]] : i32, f32)
+  %1:2 = scf.while (%arg2 = %arg0, %arg3 = %arg1) : (i32, f32) -> (i64, f64) {
+  // CHECK:   ^[[BEFORE]](%[[ARG2:.*]]: i32, %[[ARG3:.*]]: f32):
+    // CHECK:   %[[VAL1:.*]] = zexti %[[ARG2]] : i32 to i64
+    %2 = zexti %arg2 : i32 to i64
+    // CHECK:   %[[VAL2:.*]] = fpext %[[ARG3]] : f32 to f64
+    %3 = fpext %arg3 : f32 to f64
+    // CHECK:   cond_br %[[COND]],
+    // CHECK-SAME:      ^[[AFTER:.*]](%[[VAL1]], %[[VAL2]] : i64, f64),
+    // CHECK-SAME:      ^[[CONT:.*]]
+    scf.condition(%0) %2, %3 : i64, f64
+  } do {
+  // CHECK:   ^[[AFTER]](%[[ARG4:.*]]: i64, %[[ARG5:.*]]: f64):
+  ^bb0(%arg2: i64, %arg3: f64):  // no predecessors
+    // CHECK:   br ^[[BEFORE]](%{{.*}}, %{{.*}} : i32, f32)
+    scf.yield %c0_i32, %cst : i32, f32
+  }
+  // CHECK:   ^[[CONT]]:
+  // CHECK:     return
+  return
+}
+
+// CHECK-LABEL: @nested_while_ops
+func @nested_while_ops(%arg0: f32) -> i64 {
+  // CHECK:       br ^[[OUTER_BEFORE:.*]](%{{.*}} : f32)
+  %0 = scf.while(%outer = %arg0) : (f32) -> i64 {
+    // CHECK:   ^[[OUTER_BEFORE]](%{{.*}}: f32):
+    // CHECK:     %[[OUTER_COND:.*]] = "test.outer_before_pre"() : () -> i1
+    %cond = "test.outer_before_pre"() : () -> i1
+    // CHECK:     br ^[[INNER_BEFORE_BEFORE:.*]](%{{.*}} : f32)
+    %1 = scf.while(%inner = %outer) : (f32) -> i64 {
+      // CHECK: ^[[INNER_BEFORE_BEFORE]](%{{.*}}: f32):
+      // CHECK:   %[[INNER1:.*]]:2 = "test.inner_before"(%{{.*}}) : (f32) -> (i1, i64)
+      %2:2 = "test.inner_before"(%inner) : (f32) -> (i1, i64)
+      // CHECK:   cond_br %[[INNER1]]#0,
+      // CHECK-SAME:      ^[[INNER_BEFORE_AFTER:.*]](%[[INNER1]]#1 : i64),
+      // CHECK-SAME:      ^[[OUTER_BEFORE_LAST:.*]]
+      scf.condition(%2#0) %2#1 : i64
+    } do {
+      // CHECK: ^[[INNER_BEFORE_AFTER]](%{{.*}}: i64):
+    ^bb0(%arg1: i64):
+      // CHECK:   %[[INNER2:.*]] = "test.inner_after"(%{{.*}}) : (i64) -> f32
+      %3 = "test.inner_after"(%arg1) : (i64) -> f32
+      // CHECK:   br ^[[INNER_BEFORE_BEFORE]](%[[INNER2]] : f32)
+      scf.yield %3 : f32
+    }
+    // CHECK:   ^[[OUTER_BEFORE_LAST]]:
+    // CHECK:     "test.outer_before_post"() : () -> ()
+    "test.outer_before_post"() : () -> ()
+    // CHECK:     cond_br %[[OUTER_COND]],
+    // CHECK-SAME:        ^[[OUTER_AFTER:.*]](%[[INNER1]]#1 : i64),
+    // CHECK-SAME:        ^[[CONTINUATION:.*]]
+    scf.condition(%cond) %1 : i64
+  } do {
+    // CHECK:   ^[[OUTER_AFTER]](%{{.*}}: i64):
+  ^bb2(%arg2: i64):
+    // CHECK:     "test.outer_after_pre"(%{{.*}}) : (i64) -> ()
+    "test.outer_after_pre"(%arg2) : (i64) -> ()
+    // CHECK:     br ^[[INNER_AFTER_BEFORE:.*]](%{{.*}} : i64)
+    %4 = scf.while(%inner = %arg2) : (i64) -> f32 {
+      // CHECK: ^[[INNER_AFTER_BEFORE]](%{{.*}}: i64):
+      // CHECK:   %[[INNER3:.*]]:2 = "test.inner2_before"(%{{.*}}) : (i64) -> (i1, f32)
+      %5:2 = "test.inner2_before"(%inner) : (i64) -> (i1, f32)
+      // CHECK:   cond_br %[[INNER3]]#0,
+      // CHECK-SAME:      ^[[INNER_AFTER_AFTER:.*]](%[[INNER3]]#1 : f32),
+      // CHECK-SAME:      ^[[OUTER_AFTER_LAST:.*]]
+      scf.condition(%5#0) %5#1 : f32
+    } do {
+      // CHECK: ^[[INNER_AFTER_AFTER]](%{{.*}}: f32):
+    ^bb3(%arg3: f32):
+      // CHECK:   %{{.*}} = "test.inner2_after"(%{{.*}}) : (f32) -> i64
+      %6 = "test.inner2_after"(%arg3) : (f32) -> i64
+      // CHECK:   br ^[[INNER_AFTER_BEFORE]](%{{.*}} : i64)
+      scf.yield %6 : i64
+    }
+    // CHECK:   ^[[OUTER_AFTER_LAST]]:
+    // CHECK:     "test.outer_after_post"() : () -> ()
+    "test.outer_after_post"() : () -> ()
+    // CHECK:     br ^[[OUTER_BEFORE]](%[[INNER3]]#1 : f32)
+    scf.yield %4 : f32
+  }
+  // CHECK:     ^[[CONTINUATION]]:
+  // CHECK:       return %{{.*}} : i64
+  return %0 : i64
+}
+


        


More information about the Mlir-commits mailing list