[Mlir-commits] [mlir] a8feeee - [mlir] Add canonicalization for cond_br that feed into a cond_br on the same condition

River Riddle llvmlistbot at llvm.org
Sun Oct 18 13:57:00 PDT 2020


Author: River Riddle
Date: 2020-10-18T13:51:02-07:00
New Revision: a8feeee15fea0fcc936fa4a6f2eb891c90d69c3e

URL: https://github.com/llvm/llvm-project/commit/a8feeee15fea0fcc936fa4a6f2eb891c90d69c3e
DIFF: https://github.com/llvm/llvm-project/commit/a8feeee15fea0fcc936fa4a6f2eb891c90d69c3e.diff

LOG: [mlir] Add canonicalization for cond_br that feed into a cond_br on the same condition

```
   ...
   cond_br %cond, ^bb1(...), ^bb2(...)
 ...
 ^bb1: // has single predecessor
   ...
   cond_br %cond, ^bb3(...), ^bb4(...)
```

 ->

```
   ...
   cond_br %cond, ^bb1(...), ^bb2(...)
 ...
 ^bb1: // has single predecessor
   ...
   br ^bb3(...)
```

Differential Revision: https://reviews.llvm.org/D89604

Added: 
    

Modified: 
    mlir/lib/Dialect/StandardOps/IR/Ops.cpp
    mlir/test/Dialect/Standard/canonicalize-cf.mlir
    mlir/test/Transforms/canonicalize-block-merge.mlir

Removed: 
    


################################################################################
diff  --git a/mlir/lib/Dialect/StandardOps/IR/Ops.cpp b/mlir/lib/Dialect/StandardOps/IR/Ops.cpp
index 444b729ee751..d682ebcd1d8d 100644
--- a/mlir/lib/Dialect/StandardOps/IR/Ops.cpp
+++ b/mlir/lib/Dialect/StandardOps/IR/Ops.cpp
@@ -1063,12 +1063,58 @@ struct SimplifyCondBranchIdenticalSuccessors
     return success();
   }
 };
+
+///   ...
+///   cond_br %cond, ^bb1(...), ^bb2(...)
+/// ...
+/// ^bb1: // has single predecessor
+///   ...
+///   cond_br %cond, ^bb3(...), ^bb4(...)
+///
+/// ->
+///
+///   ...
+///   cond_br %cond, ^bb1(...), ^bb2(...)
+/// ...
+/// ^bb1: // has single predecessor
+///   ...
+///   br ^bb3(...)
+///
+struct SimplifyCondBranchFromCondBranchOnSameCondition
+    : public OpRewritePattern<CondBranchOp> {
+  using OpRewritePattern<CondBranchOp>::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(CondBranchOp condbr,
+                                PatternRewriter &rewriter) const override {
+    // Check that we have a single distinct predecessor.
+    Block *currentBlock = condbr.getOperation()->getBlock();
+    Block *predecessor = currentBlock->getSinglePredecessor();
+    if (!predecessor)
+      return failure();
+
+    // Check that the predecessor terminates with a conditional branch to this
+    // block and that it branches on the same condition.
+    auto predBranch = dyn_cast<CondBranchOp>(predecessor->getTerminator());
+    if (!predBranch || condbr.getCondition() != predBranch.getCondition())
+      return failure();
+
+    // Fold this branch to an unconditional branch.
+    if (currentBlock == predBranch.trueDest())
+      rewriter.replaceOpWithNewOp<BranchOp>(condbr, condbr.trueDest(),
+                                            condbr.trueDestOperands());
+    else
+      rewriter.replaceOpWithNewOp<BranchOp>(condbr, condbr.falseDest(),
+                                            condbr.falseDestOperands());
+    return success();
+  }
+};
 } // end anonymous namespace
 
 void CondBranchOp::getCanonicalizationPatterns(
     OwningRewritePatternList &results, MLIRContext *context) {
   results.insert<SimplifyConstCondBranchPred, SimplifyPassThroughCondBranch,
-                 SimplifyCondBranchIdenticalSuccessors>(context);
+                 SimplifyCondBranchIdenticalSuccessors,
+                 SimplifyCondBranchFromCondBranchOnSameCondition>(context);
 }
 
 Optional<MutableOperandRange>

diff  --git a/mlir/test/Dialect/Standard/canonicalize-cf.mlir b/mlir/test/Dialect/Standard/canonicalize-cf.mlir
index 0cdf7fdc1471..5f18562b7ad5 100644
--- a/mlir/test/Dialect/Standard/canonicalize-cf.mlir
+++ b/mlir/test/Dialect/Standard/canonicalize-cf.mlir
@@ -139,6 +139,27 @@ func @cond_br_pass_through_fail(%cond : i1) {
   return
 }
 
+/// Test folding conditional branches that are successors of conditional
+/// branches with the same condition.
+
+// CHECK-LABEL: func @cond_br_from_cond_br_with_same_condition
+func @cond_br_from_cond_br_with_same_condition(%cond : i1) {
+  // CHECK:   cond_br %{{.*}}, ^bb1, ^bb2
+  // CHECK: ^bb1:
+  // CHECK:   return
+
+  cond_br %cond, ^bb1, ^bb2
+
+^bb1:
+  cond_br %cond, ^bb3, ^bb2
+
+^bb2:
+  "foo.terminator"() : () -> ()
+
+^bb3:
+  return
+}
+
 // -----
 
 // Erase assertion if condition is known to be true at compile time.

diff  --git a/mlir/test/Transforms/canonicalize-block-merge.mlir b/mlir/test/Transforms/canonicalize-block-merge.mlir
index 607f6cafb9de..277b295e99be 100644
--- a/mlir/test/Transforms/canonicalize-block-merge.mlir
+++ b/mlir/test/Transforms/canonicalize-block-merge.mlir
@@ -178,23 +178,23 @@ func @contains_regions(%cond : i1) {
 // block is used in another.
 
 // CHECK-LABEL: func @mismatch_loop(
-// CHECK-SAME: %[[ARG:.*]]: i1
-func @mismatch_loop(%cond : i1) {
-  // CHECK: cond_br %{{.*}}, ^bb1(%[[ARG]] : i1), ^bb2
+// CHECK-SAME: %[[ARG:.*]]: i1, %[[ARG2:.*]]: i1
+func @mismatch_loop(%cond : i1, %cond2 : i1) {
+  // CHECK: cond_br %{{.*}}, ^bb1(%[[ARG2]] : i1), ^bb2
 
   cond_br %cond, ^bb2, ^bb3
 
 ^bb1:
-  // CHECK: ^bb1(%[[ARG2:.*]]: i1):
+  // CHECK: ^bb1(%[[ARG3:.*]]: i1):
   // CHECK-NEXT: %[[LOOP_CARRY:.*]] = "foo.op"
-  // CHECK-NEXT: cond_br %[[ARG2]], ^bb1(%[[LOOP_CARRY]] : i1), ^bb2
+  // CHECK-NEXT: cond_br %[[ARG3]], ^bb1(%[[LOOP_CARRY]] : i1), ^bb2
 
   %ignored = "foo.op"() : () -> (i1)
-  cond_br %cond2, ^bb1, ^bb3
+  cond_br %cond3, ^bb1, ^bb3
 
 ^bb2:
-  %cond2 = "foo.op"() : () -> (i1)
-  cond_br %cond, ^bb1, ^bb3
+  %cond3 = "foo.op"() : () -> (i1)
+  cond_br %cond2, ^bb1, ^bb3
 
 ^bb3:
   // CHECK: ^bb2:


        


More information about the Mlir-commits mailing list