[Mlir-commits] [mlir] a8feeee - [mlir] Add canonicalization for cond_br that feed into a cond_br on the same condition
River Riddle
llvmlistbot at llvm.org
Sun Oct 18 13:57:00 PDT 2020
Author: River Riddle
Date: 2020-10-18T13:51:02-07:00
New Revision: a8feeee15fea0fcc936fa4a6f2eb891c90d69c3e
URL: https://github.com/llvm/llvm-project/commit/a8feeee15fea0fcc936fa4a6f2eb891c90d69c3e
DIFF: https://github.com/llvm/llvm-project/commit/a8feeee15fea0fcc936fa4a6f2eb891c90d69c3e.diff
LOG: [mlir] Add canonicalization for cond_br that feed into a cond_br on the same condition
```
...
cond_br %cond, ^bb1(...), ^bb2(...)
...
^bb1: // has single predecessor
...
cond_br %cond, ^bb3(...), ^bb4(...)
```
->
```
...
cond_br %cond, ^bb1(...), ^bb2(...)
...
^bb1: // has single predecessor
...
br ^bb3(...)
```
Differential Revision: https://reviews.llvm.org/D89604
Added:
Modified:
mlir/lib/Dialect/StandardOps/IR/Ops.cpp
mlir/test/Dialect/Standard/canonicalize-cf.mlir
mlir/test/Transforms/canonicalize-block-merge.mlir
Removed:
################################################################################
diff --git a/mlir/lib/Dialect/StandardOps/IR/Ops.cpp b/mlir/lib/Dialect/StandardOps/IR/Ops.cpp
index 444b729ee751..d682ebcd1d8d 100644
--- a/mlir/lib/Dialect/StandardOps/IR/Ops.cpp
+++ b/mlir/lib/Dialect/StandardOps/IR/Ops.cpp
@@ -1063,12 +1063,58 @@ struct SimplifyCondBranchIdenticalSuccessors
return success();
}
};
+
+/// ...
+/// cond_br %cond, ^bb1(...), ^bb2(...)
+/// ...
+/// ^bb1: // has single predecessor
+/// ...
+/// cond_br %cond, ^bb3(...), ^bb4(...)
+///
+/// ->
+///
+/// ...
+/// cond_br %cond, ^bb1(...), ^bb2(...)
+/// ...
+/// ^bb1: // has single predecessor
+/// ...
+/// br ^bb3(...)
+///
+struct SimplifyCondBranchFromCondBranchOnSameCondition
+ : public OpRewritePattern<CondBranchOp> {
+ using OpRewritePattern<CondBranchOp>::OpRewritePattern;
+
+ LogicalResult matchAndRewrite(CondBranchOp condbr,
+ PatternRewriter &rewriter) const override {
+ // Check that we have a single distinct predecessor.
+ Block *currentBlock = condbr.getOperation()->getBlock();
+ Block *predecessor = currentBlock->getSinglePredecessor();
+ if (!predecessor)
+ return failure();
+
+ // Check that the predecessor terminates with a conditional branch to this
+ // block and that it branches on the same condition.
+ auto predBranch = dyn_cast<CondBranchOp>(predecessor->getTerminator());
+ if (!predBranch || condbr.getCondition() != predBranch.getCondition())
+ return failure();
+
+ // Fold this branch to an unconditional branch.
+ if (currentBlock == predBranch.trueDest())
+ rewriter.replaceOpWithNewOp<BranchOp>(condbr, condbr.trueDest(),
+ condbr.trueDestOperands());
+ else
+ rewriter.replaceOpWithNewOp<BranchOp>(condbr, condbr.falseDest(),
+ condbr.falseDestOperands());
+ return success();
+ }
+};
} // end anonymous namespace
void CondBranchOp::getCanonicalizationPatterns(
OwningRewritePatternList &results, MLIRContext *context) {
results.insert<SimplifyConstCondBranchPred, SimplifyPassThroughCondBranch,
- SimplifyCondBranchIdenticalSuccessors>(context);
+ SimplifyCondBranchIdenticalSuccessors,
+ SimplifyCondBranchFromCondBranchOnSameCondition>(context);
}
Optional<MutableOperandRange>
diff --git a/mlir/test/Dialect/Standard/canonicalize-cf.mlir b/mlir/test/Dialect/Standard/canonicalize-cf.mlir
index 0cdf7fdc1471..5f18562b7ad5 100644
--- a/mlir/test/Dialect/Standard/canonicalize-cf.mlir
+++ b/mlir/test/Dialect/Standard/canonicalize-cf.mlir
@@ -139,6 +139,27 @@ func @cond_br_pass_through_fail(%cond : i1) {
return
}
+/// Test folding conditional branches that are successors of conditional
+/// branches with the same condition.
+
+// CHECK-LABEL: func @cond_br_from_cond_br_with_same_condition
+func @cond_br_from_cond_br_with_same_condition(%cond : i1) {
+ // CHECK: cond_br %{{.*}}, ^bb1, ^bb2
+ // CHECK: ^bb1:
+ // CHECK: return
+
+ cond_br %cond, ^bb1, ^bb2
+
+^bb1:
+ cond_br %cond, ^bb3, ^bb2
+
+^bb2:
+ "foo.terminator"() : () -> ()
+
+^bb3:
+ return
+}
+
// -----
// Erase assertion if condition is known to be true at compile time.
diff --git a/mlir/test/Transforms/canonicalize-block-merge.mlir b/mlir/test/Transforms/canonicalize-block-merge.mlir
index 607f6cafb9de..277b295e99be 100644
--- a/mlir/test/Transforms/canonicalize-block-merge.mlir
+++ b/mlir/test/Transforms/canonicalize-block-merge.mlir
@@ -178,23 +178,23 @@ func @contains_regions(%cond : i1) {
// block is used in another.
// CHECK-LABEL: func @mismatch_loop(
-// CHECK-SAME: %[[ARG:.*]]: i1
-func @mismatch_loop(%cond : i1) {
- // CHECK: cond_br %{{.*}}, ^bb1(%[[ARG]] : i1), ^bb2
+// CHECK-SAME: %[[ARG:.*]]: i1, %[[ARG2:.*]]: i1
+func @mismatch_loop(%cond : i1, %cond2 : i1) {
+ // CHECK: cond_br %{{.*}}, ^bb1(%[[ARG2]] : i1), ^bb2
cond_br %cond, ^bb2, ^bb3
^bb1:
- // CHECK: ^bb1(%[[ARG2:.*]]: i1):
+ // CHECK: ^bb1(%[[ARG3:.*]]: i1):
// CHECK-NEXT: %[[LOOP_CARRY:.*]] = "foo.op"
- // CHECK-NEXT: cond_br %[[ARG2]], ^bb1(%[[LOOP_CARRY]] : i1), ^bb2
+ // CHECK-NEXT: cond_br %[[ARG3]], ^bb1(%[[LOOP_CARRY]] : i1), ^bb2
%ignored = "foo.op"() : () -> (i1)
- cond_br %cond2, ^bb1, ^bb3
+ cond_br %cond3, ^bb1, ^bb3
^bb2:
- %cond2 = "foo.op"() : () -> (i1)
- cond_br %cond, ^bb1, ^bb3
+ %cond3 = "foo.op"() : () -> (i1)
+ cond_br %cond2, ^bb1, ^bb3
^bb3:
// CHECK: ^bb2:
More information about the Mlir-commits
mailing list