[Mlir-commits] [mlir] 2fafe7f - [mlir][Standard] Add support for canonicalizing branches to passthrough blocks
River Riddle
llvmlistbot at llvm.org
Thu Apr 23 04:51:07 PDT 2020
Author: River Riddle
Date: 2020-04-23T04:42:02-07:00
New Revision: 2fafe7ff591da8de6454c11c4756cc13a89b1c27
URL: https://github.com/llvm/llvm-project/commit/2fafe7ff591da8de6454c11c4756cc13a89b1c27
DIFF: https://github.com/llvm/llvm-project/commit/2fafe7ff591da8de6454c11c4756cc13a89b1c27.diff
LOG: [mlir][Standard] Add support for canonicalizing branches to passthrough blocks
This revision adds support for canonicalizing the following:
```
  br ^bb1
^bb1:
  br ^bbN(...)

-> br ^bbN(...)
```
Differential Revision: https://reviews.llvm.org/D78683
Added:
Modified:
mlir/lib/Dialect/StandardOps/IR/Ops.cpp
mlir/test/Dialect/Standard/canonicalize-cf.mlir
Removed:
################################################################################
diff --git a/mlir/lib/Dialect/StandardOps/IR/Ops.cpp b/mlir/lib/Dialect/StandardOps/IR/Ops.cpp
index f1fb0f90b57a..3294210d5218 100644
--- a/mlir/lib/Dialect/StandardOps/IR/Ops.cpp
+++ b/mlir/lib/Dialect/StandardOps/IR/Ops.cpp
@@ -566,6 +566,54 @@ static LogicalResult verify(AtomicYieldOp op) {
// BranchOp
//===----------------------------------------------------------------------===//
+/// Given a successor, try to collapse it to a new destination if it only
+/// contains a passthrough unconditional branch. If the successor is
+/// collapsible, `successor` and `successorOperands` are updated to reference
+/// the new destination and values; on failure nothing is modified. `argStorage`
+/// receives remapped operands if needed and must outlive `successorOperands`.
+static LogicalResult collapseBranch(Block *&successor,
+ ValueRange &successorOperands,
+ SmallVectorImpl<Value> &argStorage) {
+ // Check that the successor contains only a single operation.
+ if (std::next(successor->begin()) != successor->end())
+ return failure();
+ // Check that the terminator is an unconditional branch.
+ BranchOp successorBranch = dyn_cast<BranchOp>(successor->getTerminator());
+ if (!successorBranch)
+ return failure();
+ // Check that the block arguments are only used by the terminator.
+ for (BlockArgument arg : successor->getArguments()) {
+ for (Operation *user : arg.getUsers())
+ if (user != successorBranch)
+ return failure();
+ }
+ // Don't collapse a block that branches to itself (an infinite loop).
+ Block *successorDest = successorBranch.getDest();
+ if (successorDest == successor)
+ return failure();
+
+ // Update the operands to the successor. If the successor block has no
+ // arguments, the branch operands can be forwarded as-is.
+ OperandRange operands = successorBranch.getOperands();
+ if (successor->args_empty()) {
+ successor = successorDest;
+ successorOperands = operands;
+ return success();
+ }
+
+ // Otherwise, map the successor's arguments back to the incoming operands.
+ for (Value operand : operands) {
+ BlockArgument argOperand = operand.dyn_cast<BlockArgument>();
+ if (argOperand && argOperand.getOwner() == successor)
+ argStorage.push_back(successorOperands[argOperand.getArgNumber()]);
+ else
+ argStorage.push_back(operand);
+ }
+ successor = successorDest;
+ successorOperands = argStorage;
+ return success();
+}
+
namespace {
/// Simplify a branch to a block that has a single predecessor. This effectively
/// merges the two blocks.
@@ -586,6 +634,33 @@ struct SimplifyBrToBlockWithSinglePred : public OpRewritePattern<BranchOp> {
return success();
}
};
+
+/// Redirect a `br` whose destination block only forwards to another block:
+///   br ^bb1
+/// ^bb1:
+///   br ^bbN(...)
+///
+/// -> br ^bbN(...)
+struct SimplifyPassThroughBr : public OpRewritePattern<BranchOp> {
+ using OpRewritePattern<BranchOp>::OpRewritePattern;
+
+ LogicalResult matchAndRewrite(BranchOp op,
+ PatternRewriter &rewriter) const override {
+ Block *dest = op.getDest();
+ ValueRange destOperands = op.getOperands();
+ SmallVector<Value, 4> destOperandStorage;
+
+ // Bail out on branches to their own block; otherwise try to collapse the
+ // successor into its pass-through destination.
+ if (dest == op.getOperation()->getBlock() ||
+ failed(collapseBranch(dest, destOperands, destOperandStorage)))
+ return failure();
+
+ // Replace the branch with one targeting the collapsed successor.
+ rewriter.replaceOpWithNewOp<BranchOp>(op, dest, destOperands);
+ return success();
+ }
+};
} // end anonymous namespace.
Block *BranchOp::getDest() { return getSuccessor(); }
@@ -598,7 +673,8 @@ void BranchOp::eraseOperand(unsigned index) {
void BranchOp::getCanonicalizationPatterns(OwningRewritePatternList &results,
MLIRContext *context) {
- results.insert<SimplifyBrToBlockWithSinglePred>(context);
+ results.insert<SimplifyBrToBlockWithSinglePred, SimplifyPassThroughBr>(
+ context);
}
Optional<OperandRange> BranchOp::getSuccessorOperands(unsigned index) {
@@ -889,53 +965,6 @@ struct SimplifyPassThroughCondBranch : public OpRewritePattern<CondBranchOp> {
falseDest, falseDestOperands);
return success();
}
-
- /// Given a successor, try to collapse it to a new destination if it only
- /// contains a passthrough unconditional branch. If the successor is
- /// collapsable, `successor` and `successorOperands` are updated to reference
- /// the new destination and values. `argStorage` is an optional storage to use
- /// if operands to the collapsed successor need to be remapped.
- LogicalResult collapseBranch(Block *&successor, ValueRange &successorOperands,
- SmallVectorImpl<Value> &argStorage) const {
- // Check that the successor only contains a unconditional branch.
- if (std::next(successor->begin()) != successor->end())
- return failure();
- // Check that the terminator is an unconditional branch.
- BranchOp successorBranch = dyn_cast<BranchOp>(successor->getTerminator());
- if (!successorBranch)
- return failure();
- // Check that the arguments are only used within the terminator.
- for (BlockArgument arg : successor->getArguments()) {
- for (Operation *user : arg.getUsers())
- if (user != successorBranch)
- return failure();
- }
- // Don't try to collapse branches to infinite loops.
- Block *successorDest = successorBranch.getDest();
- if (successorDest == successor)
- return failure();
-
- // Update the operands to the successor. If the branch parent has no
- // arguments, we can use the branch operands directly.
- OperandRange operands = successorBranch.getOperands();
- if (successor->args_empty()) {
- successor = successorBranch.getDest();
- successorOperands = operands;
- return success();
- }
-
- // Otherwise, we need to remap any argument operands.
- for (Value operand : operands) {
- BlockArgument argOperand = operand.dyn_cast<BlockArgument>();
- if (argOperand && argOperand.getOwner() == successor)
- argStorage.push_back(successorOperands[argOperand.getArgNumber()]);
- else
- argStorage.push_back(operand);
- }
- successor = successorBranch.getDest();
- successorOperands = argStorage;
- return success();
- }
};
/// cond_br %cond, ^bb1(A, ..., N), ^bb1(A, ..., N)
diff --git a/mlir/test/Dialect/Standard/canonicalize-cf.mlir b/mlir/test/Dialect/Standard/canonicalize-cf.mlir
index 8b7b3020fae0..71ee7f1fcfe0 100644
--- a/mlir/test/Dialect/Standard/canonicalize-cf.mlir
+++ b/mlir/test/Dialect/Standard/canonicalize-cf.mlir
@@ -12,6 +12,26 @@ func @br_folding() -> i32 {
return %x : i32
}
+/// Test that pass-through successors of BranchOp get folded, remapping forwarded block arguments.
+
+// CHECK-LABEL: func @br_passthrough(
+// CHECK-SAME: %[[ARG0:.*]]: i32, %[[ARG1:.*]]: i32
+func @br_passthrough(%arg0 : i32, %arg1 : i32) -> (i32, i32) {
+ "foo.switch"() [^bb1, ^bb2, ^bb3] : () -> ()
+
+^bb1:
+ // CHECK: ^bb1:
+ // CHECK-NEXT: br ^bb3(%[[ARG0]], %[[ARG1]] : i32, i32)
+ // ^bb2 only forwards to ^bb3, so this br is redirected there with %arg2 -> %arg0.
+ br ^bb2(%arg0 : i32)
+
+^bb2(%arg2 : i32):
+ br ^bb3(%arg2, %arg1 : i32, i32)
+
+^bb3(%arg4 : i32, %arg5 : i32):
+ return %arg4, %arg5 : i32, i32
+}
+
/// Test the folding of CondBranchOp with a constant condition.
// CHECK-LABEL: func @cond_br_folding(
@@ -103,9 +123,9 @@ func @cond_br_and_br_folding(%a : i32) {
/// Test that pass-through successors of CondBranchOp get folded.
-// CHECK-LABEL: func @cond_br_pass_through(
+// CHECK-LABEL: func @cond_br_passthrough(
// CHECK-SAME: %[[ARG0:.*]]: i32, %[[ARG1:.*]]: i32, %[[ARG2:.*]]: i32, %[[COND:.*]]: i1
-func @cond_br_pass_through(%arg0 : i32, %arg1 : i32, %arg2 : i32, %cond : i1) -> (i32, i32) {
+func @cond_br_passthrough(%arg0 : i32, %arg1 : i32, %arg2 : i32, %cond : i1) -> (i32, i32) {
// CHECK: %[[RES:.*]] = select %[[COND]], %[[ARG0]], %[[ARG2]]
// CHECK: %[[RES2:.*]] = select %[[COND]], %[[ARG1]], %[[ARG2]]
// CHECK: return %[[RES]], %[[RES2]]
More information about the Mlir-commits
mailing list