[Mlir-commits] [mlir] 9a11c70 - [SCF] Handle lowering of Execute region to Standard CFG
William S. Moses
llvmlistbot at llvm.org
Wed Jul 7 12:27:40 PDT 2021
Author: William S. Moses
Date: 2021-07-07T15:27:21-04:00
New Revision: 9a11c70c1856f4e801d0863c552c754f28110237
URL: https://github.com/llvm/llvm-project/commit/9a11c70c1856f4e801d0863c552c754f28110237
DIFF: https://github.com/llvm/llvm-project/commit/9a11c70c1856f4e801d0863c552c754f28110237.diff
LOG: [SCF] Handle lowering of Execute region to Standard CFG
Lower scf.execute_region to the Standard-dialect CFG (a step toward LLVM) by inlining its region and replacing each scf.yield terminator with a branch to the continuation block.
Differential Revision: https://reviews.llvm.org/D105567
Added:
Modified:
mlir/lib/Conversion/SCFToStandard/SCFToStandard.cpp
mlir/test/Conversion/SCFToStandard/convert-to-cfg.mlir
Removed:
################################################################################
diff --git a/mlir/lib/Conversion/SCFToStandard/SCFToStandard.cpp b/mlir/lib/Conversion/SCFToStandard/SCFToStandard.cpp
index 6efba3fc816ce..4f451655e713d 100644
--- a/mlir/lib/Conversion/SCFToStandard/SCFToStandard.cpp
+++ b/mlir/lib/Conversion/SCFToStandard/SCFToStandard.cpp
@@ -194,6 +194,13 @@ struct IfLowering : public OpRewritePattern<IfOp> {
PatternRewriter &rewriter) const override;
};
+/// Pattern that lowers `scf.execute_region` into unstructured CFG branches,
+/// inlining the op's region into the enclosing region.
+struct ExecuteRegionLowering : OpRewritePattern<ExecuteRegionOp> {
+  using OpRewritePattern::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(ExecuteRegionOp op,
+                                PatternRewriter &rewriter) const override;
+};
+
struct ParallelLowering : public OpRewritePattern<mlir::scf::ParallelOp> {
using OpRewritePattern<mlir::scf::ParallelOp>::OpRewritePattern;
@@ -400,6 +407,38 @@ LogicalResult IfLowering::matchAndRewrite(IfOp ifOp,
return success();
}
+/// Lowers an `scf.execute_region` op by splicing its region into the parent
+/// CFG:
+///   * the block holding the op is split before the op, so the op and every
+///     operation after it move into a continuation block;
+///   * the original block is terminated with a branch to the region's entry
+///     block;
+///   * every `scf.yield` in the region is replaced by a branch to the
+///     continuation block that forwards the yielded values;
+///   * the continuation block gets one argument per op result, and those
+///     arguments replace the op's results (which also erases the op).
+LogicalResult
+ExecuteRegionLowering::matchAndRewrite(ExecuteRegionOp op,
+                                       PatternRewriter &rewriter) const {
+  auto loc = op.getLoc();
+
+  // Split at the op itself rather than relying on the rewriter's current
+  // insertion point, so the lowering does not depend on where the driver
+  // happened to leave it.
+  Block *condBlock = op->getBlock();
+  auto opPosition = op->getIterator();
+  Block *remainingOpsBlock = rewriter.splitBlock(condBlock, opPosition);
+
+  // Enter the region unconditionally from the (now truncated) original block.
+  auto &region = op.region();
+  rewriter.setInsertionPointToEnd(condBlock);
+  rewriter.create<BranchOp>(loc, &region.front());
+
+  // Turn each `scf.yield` into a branch to the continuation block. Other
+  // terminators (e.g. a `cond_br` between region blocks) are left untouched.
+  for (Block &block : region) {
+    if (auto terminator = dyn_cast<scf::YieldOp>(block.getTerminator())) {
+      ValueRange terminatorOperands = terminator->getOperands();
+      rewriter.setInsertionPointToEnd(&block);
+      // Keep the yield's own location so diagnostics point at the exit taken.
+      rewriter.create<BranchOp>(terminator.getLoc(), remainingOpsBlock,
+                                terminatorOperands);
+      rewriter.eraseOp(terminator);
+    }
+  }
+
+  rewriter.inlineRegionBefore(region, remainingOpsBlock);
+
+  // The yielded values arrive as continuation-block arguments; they replace
+  // the op's results one-for-one.
+  auto resultArgs = remainingOpsBlock->addArguments(op->getResultTypes());
+  SmallVector<Value> vals(resultArgs.begin(), resultArgs.end());
+  rewriter.replaceOp(op, vals);
+  return success();
+}
+
LogicalResult
ParallelLowering::matchAndRewrite(ParallelOp parallelOp,
PatternRewriter &rewriter) const {
@@ -569,8 +608,8 @@ DoWhileLowering::matchAndRewrite(WhileOp whileOp,
}
+/// Registers the SCF-to-Standard lowering patterns, including
+/// ExecuteRegionLowering, into `patterns`.
void mlir::populateLoopToStdConversionPatterns(RewritePatternSet &patterns) {
- patterns.add<ForLowering, IfLowering, ParallelLowering, WhileLowering>(
- patterns.getContext());
+ patterns.add<ForLowering, IfLowering, ParallelLowering, WhileLowering,
+ ExecuteRegionLowering>(patterns.getContext());
+ // DoWhileLowering also matches scf.while; the higher benefit makes the
+ // driver prefer it over WhileLowering, which uses the default benefit.
patterns.add<DoWhileLowering>(patterns.getContext(), /*benefit=*/2);
}
@@ -580,7 +619,8 @@ void SCFToStandardPass::runOnOperation() {
// Configure conversion to lower out scf.for, scf.if, scf.parallel and
// scf.while. Anything else is fine.
ConversionTarget target(getContext());
- target.addIllegalOp<scf::ForOp, scf::IfOp, scf::ParallelOp, scf::WhileOp>();
+ target.addIllegalOp<scf::ForOp, scf::IfOp, scf::ParallelOp, scf::WhileOp,
+ scf::ExecuteRegionOp>();
target.markUnknownOpDynamicallyLegal([](Operation *) { return true; });
if (failed(
applyPartialConversion(getOperation(), target, std::move(patterns))))
diff --git a/mlir/test/Conversion/SCFToStandard/convert-to-cfg.mlir b/mlir/test/Conversion/SCFToStandard/convert-to-cfg.mlir
index 793d7f706f24f..91114a7ae5deb 100644
--- a/mlir/test/Conversion/SCFToStandard/convert-to-cfg.mlir
+++ b/mlir/test/Conversion/SCFToStandard/convert-to-cfg.mlir
@@ -587,3 +587,36 @@ func @ifs_in_parallel(%arg1: index, %arg2: index, %arg3: index, %arg4: i1, %arg5
// CHECK: return
return
}
+
+// CHECK-LABEL: func @func_execute_region_elim_multi_yield
+func @func_execute_region_elim_multi_yield() {
+  "test.foo"() : () -> ()
+  %v = scf.execute_region -> i64 {
+    %c = "test.cmp"() : () -> i1
+    cond_br %c, ^bb2, ^bb3
+  ^bb2:
+    %x = "test.val1"() : () -> i64
+    scf.yield %x : i64
+  ^bb3:
+    %y = "test.val2"() : () -> i64
+    scf.yield %y : i64
+  }
+  "test.bar"(%v) : (i64) -> ()
+  return
+}
+
+// Both scf.yield exits must become branches into one continuation block whose
+// argument replaces the execute_region result.
+// CHECK: "test.foo"
+// CHECK-NOT: execute_region
+// CHECK: br ^[[rentry:.+]]
+// CHECK: ^[[rentry]]
+// CHECK: %[[cmp:.+]] = "test.cmp"
+// CHECK: cond_br %[[cmp]], ^[[bb1:.+]], ^[[bb2:.+]]
+// CHECK: ^[[bb1]]:
+// CHECK: %[[x:.+]] = "test.val1"
+// CHECK: br ^[[bb3:.+]](%[[x]] : i64)
+// CHECK: ^[[bb2]]:
+// CHECK: %[[y:.+]] = "test.val2"
+// CHECK: br ^[[bb3]](%[[y]] : i64)
+// CHECK: ^[[bb3]](%[[z:.+]]: i64):
+// CHECK: "test.bar"(%[[z]])
+// CHECK: return
More information about the Mlir-commits
mailing list