[Mlir-commits] [mlir] 2fafe7f - [mlir][Standard] Add support for canonicalizing branches to passthrough blocks

River Riddle llvmlistbot at llvm.org
Thu Apr 23 04:51:07 PDT 2020


Author: River Riddle
Date: 2020-04-23T04:42:02-07:00
New Revision: 2fafe7ff591da8de6454c11c4756cc13a89b1c27

URL: https://github.com/llvm/llvm-project/commit/2fafe7ff591da8de6454c11c4756cc13a89b1c27
DIFF: https://github.com/llvm/llvm-project/commit/2fafe7ff591da8de6454c11c4756cc13a89b1c27.diff

LOG: [mlir][Standard] Add support for canonicalizing branches to passthrough blocks

This revision adds support for canonicalizing the following:

```
   br ^bb1
 ^bb1:
   br ^bbN(...)

 -> br ^bbN(...)
```

Differential Revision: https://reviews.llvm.org/D78683

Added: 
    

Modified: 
    mlir/lib/Dialect/StandardOps/IR/Ops.cpp
    mlir/test/Dialect/Standard/canonicalize-cf.mlir

Removed: 
    


################################################################################
diff  --git a/mlir/lib/Dialect/StandardOps/IR/Ops.cpp b/mlir/lib/Dialect/StandardOps/IR/Ops.cpp
index f1fb0f90b57a..3294210d5218 100644
--- a/mlir/lib/Dialect/StandardOps/IR/Ops.cpp
+++ b/mlir/lib/Dialect/StandardOps/IR/Ops.cpp
@@ -566,6 +566,54 @@ static LogicalResult verify(AtomicYieldOp op) {
 // BranchOp
 //===----------------------------------------------------------------------===//
 
+/// Given a successor, try to collapse it to a new destination if it only
+/// contains a passthrough unconditional branch. If the successor is
+/// collapsible, `successor` and `successorOperands` are updated to reference
+/// the new destination and values. `argStorage` (expected empty) receives any
+/// remapped operands and must outlive `successorOperands`, which may view it.
+static LogicalResult collapseBranch(Block *&successor,
+                                    ValueRange &successorOperands,
+                                    SmallVectorImpl<Value> &argStorage) {
+  // Check that the successor contains a single operation (its terminator).
+  if (std::next(successor->begin()) != successor->end())
+    return failure();
+  // Check that the terminator is an unconditional branch.
+  BranchOp successorBranch = dyn_cast<BranchOp>(successor->getTerminator());
+  if (!successorBranch)
+    return failure();
+  // Check that the successor's arguments are only used by that branch.
+  for (BlockArgument arg : successor->getArguments()) {
+    for (Operation *user : arg.getUsers())
+      if (user != successorBranch)
+        return failure();
+  }
+  // Don't collapse a block that branches to itself (an infinite loop).
+  Block *successorDest = successorBranch.getDest();
+  if (successorDest == successor)
+    return failure();
+
+  // Update the operands to the successor. If the successor block has no
+  // arguments, the branch operands can be forwarded unchanged.
+  OperandRange operands = successorBranch.getOperands();
+  if (successor->args_empty()) {
+    successor = successorDest;
+    successorOperands = operands;
+    return success();
+  }
+
+  // Otherwise, replace uses of `successor`'s arguments with incoming values.
+  for (Value operand : operands) {
+    BlockArgument argOperand = operand.dyn_cast<BlockArgument>();
+    if (argOperand && argOperand.getOwner() == successor)
+      argStorage.push_back(successorOperands[argOperand.getArgNumber()]);
+    else
+      argStorage.push_back(operand);
+  }
+  successor = successorDest;
+  successorOperands = argStorage;
+  return success();
+}
+
 namespace {
 /// Simplify a branch to a block that has a single predecessor. This effectively
 /// merges the two blocks.
@@ -586,6 +634,33 @@ struct SimplifyBrToBlockWithSinglePred : public OpRewritePattern<BranchOp> {
     return success();
   }
 };
+
+/// Fold a branch to a block that only forwards to another block:
+///   br ^bb1
+/// ^bb1:
+///   br ^bbN(...)
+///  -> br ^bbN(...)
+///
+struct SimplifyPassThroughBr : public OpRewritePattern<BranchOp> {
+  using OpRewritePattern<BranchOp>::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(BranchOp op,
+                                PatternRewriter &rewriter) const override {
+    Block *dest = op.getDest();
+    ValueRange destOperands = op.getOperands();
+    SmallVector<Value, 4> destOperandStorage;
+
+    // Only try to collapse the successor when it is not the block holding
+    // this branch; `destOperandStorage` backs any remapped operands.
+    if (dest == op.getOperation()->getBlock() ||
+        failed(collapseBranch(dest, destOperands, destOperandStorage)))
+      return failure();
+
+    // Replace the branch with one that targets the collapsed successor.
+    rewriter.replaceOpWithNewOp<BranchOp>(op, dest, destOperands);
+    return success();
+  }
+};
 } // end anonymous namespace.
 
 Block *BranchOp::getDest() { return getSuccessor(); }
@@ -598,7 +673,8 @@ void BranchOp::eraseOperand(unsigned index) {
 
 void BranchOp::getCanonicalizationPatterns(OwningRewritePatternList &results,
                                            MLIRContext *context) {
-  results.insert<SimplifyBrToBlockWithSinglePred>(context);
+  results.insert<SimplifyBrToBlockWithSinglePred, SimplifyPassThroughBr>(
+      context);
 }
 
 Optional<OperandRange> BranchOp::getSuccessorOperands(unsigned index) {
@@ -889,53 +965,6 @@ struct SimplifyPassThroughCondBranch : public OpRewritePattern<CondBranchOp> {
                                               falseDest, falseDestOperands);
     return success();
   }
-
-  /// Given a successor, try to collapse it to a new destination if it only
-  /// contains a passthrough unconditional branch. If the successor is
-  /// collapsable, `successor` and `successorOperands` are updated to reference
-  /// the new destination and values. `argStorage` is an optional storage to use
-  /// if operands to the collapsed successor need to be remapped.
-  LogicalResult collapseBranch(Block *&successor, ValueRange &successorOperands,
-                               SmallVectorImpl<Value> &argStorage) const {
-    // Check that the successor only contains a unconditional branch.
-    if (std::next(successor->begin()) != successor->end())
-      return failure();
-    // Check that the terminator is an unconditional branch.
-    BranchOp successorBranch = dyn_cast<BranchOp>(successor->getTerminator());
-    if (!successorBranch)
-      return failure();
-    // Check that the arguments are only used within the terminator.
-    for (BlockArgument arg : successor->getArguments()) {
-      for (Operation *user : arg.getUsers())
-        if (user != successorBranch)
-          return failure();
-    }
-    // Don't try to collapse branches to infinite loops.
-    Block *successorDest = successorBranch.getDest();
-    if (successorDest == successor)
-      return failure();
-
-    // Update the operands to the successor. If the branch parent has no
-    // arguments, we can use the branch operands directly.
-    OperandRange operands = successorBranch.getOperands();
-    if (successor->args_empty()) {
-      successor = successorBranch.getDest();
-      successorOperands = operands;
-      return success();
-    }
-
-    // Otherwise, we need to remap any argument operands.
-    for (Value operand : operands) {
-      BlockArgument argOperand = operand.dyn_cast<BlockArgument>();
-      if (argOperand && argOperand.getOwner() == successor)
-        argStorage.push_back(successorOperands[argOperand.getArgNumber()]);
-      else
-        argStorage.push_back(operand);
-    }
-    successor = successorBranch.getDest();
-    successorOperands = argStorage;
-    return success();
-  }
 };
 
 /// cond_br %cond, ^bb1(A, ..., N), ^bb1(A, ..., N)

diff  --git a/mlir/test/Dialect/Standard/canonicalize-cf.mlir b/mlir/test/Dialect/Standard/canonicalize-cf.mlir
index 8b7b3020fae0..71ee7f1fcfe0 100644
--- a/mlir/test/Dialect/Standard/canonicalize-cf.mlir
+++ b/mlir/test/Dialect/Standard/canonicalize-cf.mlir
@@ -12,6 +12,26 @@ func @br_folding() -> i32 {
   return %x : i32
 }
 
+/// Test that pass-through successors of BranchOp get folded.
+
+// CHECK-LABEL: func @br_passthrough(
+// CHECK-SAME: %[[ARG0:.*]]: i32, %[[ARG1:.*]]: i32
+func @br_passthrough(%arg0 : i32, %arg1 : i32) -> (i32, i32) {
+  "foo.switch"() [^bb1, ^bb2, ^bb3] : () -> ()
+
+^bb1:
+  // CHECK: ^bb1:
+  // CHECK-NEXT: br ^bb3(%[[ARG0]], %[[ARG1]] : i32, i32)
+  // ^bb2 only forwards to ^bb3: expect br ^bb3 with %arg2 remapped to %arg0.
+  br ^bb2(%arg0 : i32)
+
+^bb2(%arg2 : i32):
+  br ^bb3(%arg2, %arg1 : i32, i32)
+
+^bb3(%arg4 : i32, %arg5 : i32):
+  return %arg4, %arg5 : i32, i32
+}
+
 /// Test the folding of CondBranchOp with a constant condition.
 
 // CHECK-LABEL: func @cond_br_folding(
@@ -103,9 +123,9 @@ func @cond_br_and_br_folding(%a : i32) {
 
 /// Test that pass-through successors of CondBranchOp get folded.
 
-// CHECK-LABEL: func @cond_br_pass_through(
+// CHECK-LABEL: func @cond_br_passthrough(
 // CHECK-SAME: %[[ARG0:.*]]: i32, %[[ARG1:.*]]: i32, %[[ARG2:.*]]: i32, %[[COND:.*]]: i1
-func @cond_br_pass_through(%arg0 : i32, %arg1 : i32, %arg2 : i32, %cond : i1) -> (i32, i32) {
+func @cond_br_passthrough(%arg0 : i32, %arg1 : i32, %arg2 : i32, %cond : i1) -> (i32, i32) {
   // CHECK: %[[RES:.*]] = select %[[COND]], %[[ARG0]], %[[ARG2]]
   // CHECK: %[[RES2:.*]] = select %[[COND]], %[[ARG1]], %[[ARG2]]
   // CHECK: return %[[RES]], %[[RES2]]


        


More information about the Mlir-commits mailing list