[Mlir-commits] [mlir] 9a11c70 - [SCF] Handle lowering of Execute region to Standard CFG

William S. Moses llvmlistbot at llvm.org
Wed Jul 7 12:27:40 PDT 2021


Author: William S. Moses
Date: 2021-07-07T15:27:21-04:00
New Revision: 9a11c70c1856f4e801d0863c552c754f28110237

URL: https://github.com/llvm/llvm-project/commit/9a11c70c1856f4e801d0863c552c754f28110237
DIFF: https://github.com/llvm/llvm-project/commit/9a11c70c1856f4e801d0863c552c754f28110237.diff

LOG: [SCF] Handle lowering of Execute region to Standard CFG

Lower scf.execute_region to the Standard dialect CFG by inlining its region and replacing each scf.yield terminator with a branch to the continuation block.

Differential Revision: https://reviews.llvm.org/D105567

Added: 
    

Modified: 
    mlir/lib/Conversion/SCFToStandard/SCFToStandard.cpp
    mlir/test/Conversion/SCFToStandard/convert-to-cfg.mlir

Removed: 
    


################################################################################
diff  --git a/mlir/lib/Conversion/SCFToStandard/SCFToStandard.cpp b/mlir/lib/Conversion/SCFToStandard/SCFToStandard.cpp
index 6efba3fc816ce..4f451655e713d 100644
--- a/mlir/lib/Conversion/SCFToStandard/SCFToStandard.cpp
+++ b/mlir/lib/Conversion/SCFToStandard/SCFToStandard.cpp
@@ -194,6 +194,13 @@ struct IfLowering : public OpRewritePattern<IfOp> {
                                 PatternRewriter &rewriter) const override;
 };
 
+struct ExecuteRegionLowering : public OpRewritePattern<ExecuteRegionOp> {
+  using OpRewritePattern<ExecuteRegionOp>::OpRewritePattern;
+  /// Lowers scf.execute_region by inlining its region into the parent CFG.
+  LogicalResult matchAndRewrite(ExecuteRegionOp op,
+                                PatternRewriter &rewriter) const override;
+};
+
 struct ParallelLowering : public OpRewritePattern<mlir::scf::ParallelOp> {
   using OpRewritePattern<mlir::scf::ParallelOp>::OpRewritePattern;
 
@@ -400,6 +407,38 @@ LogicalResult IfLowering::matchAndRewrite(IfOp ifOp,
   return success();
 }
 
+// Lowers scf.execute_region: split the parent block at `op`, branch into the
+// inlined region, and turn every scf.yield into a branch to the tail block.
+LogicalResult
+ExecuteRegionLowering::matchAndRewrite(ExecuteRegionOp op,
+                                       PatternRewriter &rewriter) const {
+  // NOTE: relies on the driver leaving the insertion point at `op`.
+  auto *condBlock = rewriter.getInsertionBlock();
+  auto opPosition = rewriter.getInsertionPoint();
+  auto *remainingOpsBlock = rewriter.splitBlock(condBlock, opPosition);
+
+  auto &region = op.region();
+  rewriter.setInsertionPointToEnd(condBlock);
+  rewriter.create<BranchOp>(op.getLoc(), &region.front());
+
+  // Replace each yield in place so no block ever holds two terminators.
+  for (Block &block : region) {
+    if (auto terminator = dyn_cast<scf::YieldOp>(block.getTerminator())) {
+      rewriter.setInsertionPoint(terminator);
+      rewriter.replaceOpWithNewOp<BranchOp>(terminator, remainingOpsBlock,
+                                            terminator->getOperands());
+    }
+  }
+
+  rewriter.inlineRegionBefore(region, remainingOpsBlock);
+  // The tail block's new arguments receive the yielded values.
+  SmallVector<Value> vals;
+  for (auto arg : remainingOpsBlock->addArguments(op->getResultTypes()))
+    vals.push_back(arg);
+  rewriter.replaceOp(op, vals);
+  return success();
+}
+
 LogicalResult
 ParallelLowering::matchAndRewrite(ParallelOp parallelOp,
                                   PatternRewriter &rewriter) const {
@@ -569,8 +608,8 @@ DoWhileLowering::matchAndRewrite(WhileOp whileOp,
 }
 
 void mlir::populateLoopToStdConversionPatterns(RewritePatternSet &patterns) {
-  patterns.add<ForLowering, IfLowering, ParallelLowering, WhileLowering>(
-      patterns.getContext());
+  patterns.add<ForLowering, IfLowering, ParallelLowering, WhileLowering,
+               ExecuteRegionLowering>(patterns.getContext()); // default benefit
   patterns.add<DoWhileLowering>(patterns.getContext(), /*benefit=*/2);
 }
 
@@ -580,7 +619,8 @@ void SCFToStandardPass::runOnOperation() {
   // Configure conversion to lower out scf.for, scf.if, scf.parallel and
   // scf.while. Anything else is fine.
   ConversionTarget target(getContext());
-  target.addIllegalOp<scf::ForOp, scf::IfOp, scf::ParallelOp, scf::WhileOp>();
+  target.addIllegalOp<scf::ForOp, scf::IfOp, scf::ParallelOp, scf::WhileOp,
+                      scf::ExecuteRegionOp>();
   target.markUnknownOpDynamicallyLegal([](Operation *) { return true; });
   if (failed(
           applyPartialConversion(getOperation(), target, std::move(patterns))))

diff  --git a/mlir/test/Conversion/SCFToStandard/convert-to-cfg.mlir b/mlir/test/Conversion/SCFToStandard/convert-to-cfg.mlir
index 793d7f706f24f..91114a7ae5deb 100644
--- a/mlir/test/Conversion/SCFToStandard/convert-to-cfg.mlir
+++ b/mlir/test/Conversion/SCFToStandard/convert-to-cfg.mlir
@@ -587,3 +587,36 @@ func @ifs_in_parallel(%arg1: index, %arg2: index, %arg3: index, %arg4: i1, %arg5
   // CHECK:   return
   return
 }
+
+// CHECK-LABEL: func @func_execute_region_elim_multi_yield
+func @func_execute_region_elim_multi_yield() {
+  "test.foo"() : () -> ()
+  %v = scf.execute_region -> i64 {
+    %c = "test.cmp"() : () -> i1
+    cond_br %c, ^bb2, ^bb3
+  ^bb2:
+    %x = "test.val1"() : () -> i64
+    scf.yield %x : i64
+  ^bb3:
+    %y = "test.val2"() : () -> i64
+    scf.yield %y : i64
+  }
+  "test.bar"(%v) : (i64) -> ()
+  return
+}
+// Each scf.yield becomes a branch to one shared block whose argument is %v.
+// CHECK-NOT: execute_region
+// CHECK:     "test.foo"
+// CHECK:     br ^[[rentry:.+]]
+// CHECK:   ^[[rentry]]:
+// CHECK:     %[[cmp:.+]] = "test.cmp"
+// CHECK:     cond_br %[[cmp]], ^[[bb1:.+]], ^[[bb2:.+]]
+// CHECK:   ^[[bb1]]:
+// CHECK:     %[[x:.+]] = "test.val1"
+// CHECK:     br ^[[bb3:.+]](%[[x]] : i64)
+// CHECK:   ^[[bb2]]:
+// CHECK:     %[[y:.+]] = "test.val2"
+// CHECK:     br ^[[bb3]](%[[y]] : i64)
+// CHECK:   ^[[bb3]](%[[z:.+]]: i64):
+// CHECK:     "test.bar"(%[[z]])
+// CHECK:     return


        


More information about the Mlir-commits mailing list