[Mlir-commits] [mlir] 5d6240b - [MLIR][SCF] Inline ExecuteRegion if parent can contain multiple blocks
William S. Moses
llvmlistbot at llvm.org
Mon Jun 28 10:10:21 PDT 2021
Author: William S. Moses
Date: 2021-06-28T13:09:22-04:00
New Revision: 5d6240b77e7e7199fcf0e89f6dd2f7eea3596a3c
URL: https://github.com/llvm/llvm-project/commit/5d6240b77e7e7199fcf0e89f6dd2f7eea3596a3c
DIFF: https://github.com/llvm/llvm-project/commit/5d6240b77e7e7199fcf0e89f6dd2f7eea3596a3c.diff
LOG: [MLIR][SCF] Inline ExecuteRegion if parent can contain multiple blocks
The `scf.execute_region` op is used to allow multiple blocks within SCF constructs. If the enclosing container allows multiple blocks, inline the region into it.
Differential Revision: https://reviews.llvm.org/D104960
Added:
Modified:
mlir/include/mlir/Dialect/SCF/SCFOps.td
mlir/lib/Dialect/SCF/SCF.cpp
mlir/test/Dialect/SCF/canonicalize.mlir
Removed:
################################################################################
diff --git a/mlir/include/mlir/Dialect/SCF/SCFOps.td b/mlir/include/mlir/Dialect/SCF/SCFOps.td
index c10441f59bd55..9f039b6fcda68 100644
--- a/mlir/include/mlir/Dialect/SCF/SCFOps.td
+++ b/mlir/include/mlir/Dialect/SCF/SCFOps.td
@@ -108,14 +108,8 @@ def ExecuteRegionOp : SCF_Op<"execute_region"> {
let regions = (region AnyRegion:$region);
- // TODO: If the parent is a func like op (which would be the case if all other
- // ops are from the std dialect), the inliner logic could be readily used to
- // inline.
let hasCanonicalizer = 1;
- // TODO: can fold if it returns a constant.
- // TODO: Single block execute_region ops can be readily inlined irrespective
- // of which op is a parent. Add a fold for this.
let hasFolder = 0;
}
diff --git a/mlir/lib/Dialect/SCF/SCF.cpp b/mlir/lib/Dialect/SCF/SCF.cpp
index 99d2386ced1b1..38760ca4050d3 100644
--- a/mlir/lib/Dialect/SCF/SCF.cpp
+++ b/mlir/lib/Dialect/SCF/SCF.cpp
@@ -143,23 +143,94 @@ static LogicalResult verify(ExecuteRegionOp op) {
//
// "test.foo"() : () -> ()
// %x = "test.val"() : () -> i64
-// "test.bar"(%v) : (i64) -> ()
+// "test.bar"(%x) : (i64) -> ()
//
struct SingleBlockExecuteInliner : public OpRewritePattern<ExecuteRegionOp> {
using OpRewritePattern<ExecuteRegionOp>::OpRewritePattern;
LogicalResult matchAndRewrite(ExecuteRegionOp op,
PatternRewriter &rewriter) const override {
+ // Only a region made of exactly one block is handled by this pattern;
+ // regions with several blocks are rejected here.
- if (op.region().getBlocks().size() != 1)
+ if (!llvm::hasSingleElement(op.region()))
return failure();
+ // Replace `op` with the contents of its single block, as in the example above.
replaceOpWithRegion(rewriter, op, op.region());
return success();
}
};
+// Inline an scf.execute_region whose parent op accepts multiple blocks in its
+// region (for now: a FuncOp or another scf.execute_region). The enclosing
+// block is split at the op, control branches into the region's entry block,
+// and every scf.yield becomes a branch to the continuation block, whose new
+// block arguments replace the op's results.
+// TODO: generalize the set of parent operations this pattern can inline into.
+//
+// Example:
+//
+//   func @func_execute_region_elim() {
+//     "test.foo"() : () -> ()
+//     %v = scf.execute_region -> i64 {
+//       %c = "test.cmp"() : () -> i1
+//       cond_br %c, ^bb2, ^bb3
+//     ^bb2:
+//       %x = "test.val1"() : () -> i64
+//       br ^bb4(%x : i64)
+//     ^bb3:
+//       %y = "test.val2"() : () -> i64
+//       br ^bb4(%y : i64)
+//     ^bb4(%z : i64):
+//       scf.yield %z : i64
+//     }
+//     "test.bar"(%v) : (i64) -> ()
+//     return
+//   }
+//
+// is rewritten to
+//
+//   func @func_execute_region_elim() {
+//     "test.foo"() : () -> ()
+//     %c = "test.cmp"() : () -> i1
+//     cond_br %c, ^bb1, ^bb2
+//   ^bb1:  // pred: ^bb0
+//     %x = "test.val1"() : () -> i64
+//     br ^bb3(%x : i64)
+//   ^bb2:  // pred: ^bb0
+//     %y = "test.val2"() : () -> i64
+//     br ^bb3(%y : i64)
+//   ^bb3(%z: i64):  // 2 preds: ^bb1, ^bb2
+//     "test.bar"(%z) : (i64) -> ()
+//     return
+//   }
+//
+struct MultiBlockExecuteInliner : public OpRewritePattern<ExecuteRegionOp> {
+  using OpRewritePattern<ExecuteRegionOp>::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(ExecuteRegionOp op,
+                                PatternRewriter &rewriter) const override {
+    // Only parents known to allow several blocks per region may absorb ours.
+    Operation *parent = op->getParentOp();
+    if (!isa<FuncOp, ExecuteRegionOp>(parent))
+      return failure();
+
+    // `op` and everything after it move into `tail`; `head` keeps the
+    // operations that came before `op`.
+    Block *head = op->getBlock();
+    Block *tail = rewriter.splitBlock(head, op->getIterator());
+
+    // Fall through from `head` into the region's entry block.
+    Region &body = op.region();
+    rewriter.setInsertionPointToEnd(head);
+    rewriter.create<BranchOp>(op.getLoc(), &body.front());
+
+    // Turn every scf.yield into a branch forwarding its operands to `tail`.
+    for (Block &block : body) {
+      auto yield = dyn_cast<YieldOp>(block.getTerminator());
+      if (!yield)
+        continue;
+      rewriter.setInsertionPoint(yield);
+      rewriter.create<BranchOp>(yield.getLoc(), tail, yield.results());
+      rewriter.eraseOp(yield);
+    }
+
+    // Splice the region's blocks in front of `tail`, then give `tail` one
+    // argument per op result; those arguments become the replacement values.
+    rewriter.inlineRegionBefore(body, tail);
+    SmallVector<Value> replacements;
+    for (Value result : op.getResults())
+      replacements.push_back(tail->addArgument(result.getType()));
+
+    rewriter.replaceOp(op, replacements);
+    return success();
+  }
+};
+
+// SingleBlockExecuteInliner applies regardless of the parent op, while
+// MultiBlockExecuteInliner only fires when the parent is a FuncOp or another
+// ExecuteRegionOp.
void ExecuteRegionOp::getCanonicalizationPatterns(RewritePatternSet &results,
MLIRContext *context) {
- results.add<SingleBlockExecuteInliner>(context);
+ results.add<SingleBlockExecuteInliner, MultiBlockExecuteInliner>(context);
}
//===----------------------------------------------------------------------===//
diff --git a/mlir/test/Dialect/SCF/canonicalize.mlir b/mlir/test/Dialect/SCF/canonicalize.mlir
index 8692f2d9705e0..d1789c6dfde52 100644
--- a/mlir/test/Dialect/SCF/canonicalize.mlir
+++ b/mlir/test/Dialect/SCF/canonicalize.mlir
@@ -948,3 +948,70 @@ func @execute_region_elim() {
// CHECK-NEXT: "test.bar"(%[[VAL]]) : (i64) -> ()
// CHECK-NEXT: }
+
+// -----
+
+// A multi-block scf.execute_region directly inside a func is inlined into the
+// function body: its blocks become function blocks and the value branched to
+// ^bb4 reaches "test.bar" as the continuation block's argument.
+// The label includes "()" so it cannot match func_execute_region_elim2.
+// CHECK-LABEL: func @func_execute_region_elim()
+func @func_execute_region_elim() {
+  "test.foo"() : () -> ()
+  %v = scf.execute_region -> i64 {
+    %c = "test.cmp"() : () -> i1
+    cond_br %c, ^bb2, ^bb3
+  ^bb2:
+    %x = "test.val1"() : () -> i64
+    br ^bb4(%x : i64)
+  ^bb3:
+    %y = "test.val2"() : () -> i64
+    br ^bb4(%y : i64)
+  ^bb4(%z : i64):
+    scf.yield %z : i64
+  }
+  "test.bar"(%v) : (i64) -> ()
+  return
+}
+
+// CHECK: "test.foo"
+// CHECK: %[[cmp:.+]] = "test.cmp"
+// CHECK: cond_br %[[cmp]], ^[[bb1:.+]], ^[[bb2:.+]]
+// CHECK: ^[[bb1]]: // pred: ^bb0
+// CHECK: %[[x:.+]] = "test.val1"
+// CHECK: br ^[[bb3:.+]](%[[x]] : i64)
+// CHECK: ^[[bb2]]: // pred: ^bb0
+// CHECK: %[[y:.+]] = "test.val2"
+// CHECK: br ^[[bb3]](%[[y]] : i64)
+// CHECK: ^[[bb3]](%[[z:.+]]: i64):
+// CHECK: "test.bar"(%[[z]])
+// CHECK: return
+
+
+// -----
+
+// A multi-block scf.execute_region inside a func whose blocks each end in their
+// own scf.yield: both yields must be rewritten into branches that forward
+// their value to one shared continuation block, which feeds "test.bar".
+// CHECK-LABEL: func @func_execute_region_elim2()
+func @func_execute_region_elim2() {
+  "test.foo"() : () -> ()
+  %v = scf.execute_region -> i64 {
+    %c = "test.cmp"() : () -> i1
+    cond_br %c, ^bb2, ^bb3
+  ^bb2:
+    %x = "test.val1"() : () -> i64
+    scf.yield %x : i64
+  ^bb3:
+    %y = "test.val2"() : () -> i64
+    scf.yield %y : i64
+  }
+  "test.bar"(%v) : (i64) -> ()
+  return
+}
+
+// CHECK: "test.foo"
+// CHECK: %[[cmp:.+]] = "test.cmp"
+// CHECK: cond_br %[[cmp]], ^[[bb1:.+]], ^[[bb2:.+]]
+// CHECK: ^[[bb1]]: // pred: ^bb0
+// CHECK: %[[x:.+]] = "test.val1"
+// CHECK: br ^[[bb3:.+]](%[[x]] : i64)
+// CHECK: ^[[bb2]]: // pred: ^bb0
+// CHECK: %[[y:.+]] = "test.val2"
+// CHECK: br ^[[bb3]](%[[y]] : i64)
+// CHECK: ^[[bb3]](%[[z:.+]]: i64):
+// CHECK: "test.bar"(%[[z]])
+// CHECK: return
More information about the Mlir-commits mailing list