[Mlir-commits] [mlir] 2f4b303 - [mlir][Standard] Add canonicalization for collapsing pass through cond_br successors.
River Riddle
llvmlistbot at llvm.org
Thu Apr 23 04:51:02 PDT 2020
Author: River Riddle
Date: 2020-04-23T04:42:01-07:00
New Revision: 2f4b303d683c91bb5b3799688ecc689f59ed9de5
URL: https://github.com/llvm/llvm-project/commit/2f4b303d683c91bb5b3799688ecc689f59ed9de5
DIFF: https://github.com/llvm/llvm-project/commit/2f4b303d683c91bb5b3799688ecc689f59ed9de5.diff
LOG: [mlir][Standard] Add canonicalization for collapsing pass through cond_br successors.
This revision adds support for the following canonicalization:
```
cond_br %cond, ^bb1, ^bb2
^bb1
br ^bbN(...)
^bb2
br ^bbK(...)
-> cond_br %cond, ^bbN(...), ^bbK(...)
```
Differential Revision: https://reviews.llvm.org/D78681
Added:
mlir/test/Dialect/Standard/canonicalize-cf.mlir
Modified:
mlir/lib/Dialect/StandardOps/IR/Ops.cpp
mlir/test/Transforms/canonicalize.mlir
Removed:
################################################################################
diff --git a/mlir/lib/Dialect/StandardOps/IR/Ops.cpp b/mlir/lib/Dialect/StandardOps/IR/Ops.cpp
index 0a96c9a5bf44..75bec0800628 100644
--- a/mlir/lib/Dialect/StandardOps/IR/Ops.cpp
+++ b/mlir/lib/Dialect/StandardOps/IR/Ops.cpp
@@ -862,11 +862,93 @@ struct SimplifyConstCondBranchPred : public OpRewritePattern<CondBranchOp> {
return failure();
}
};
-} // end anonymous namespace.
+
+/// Collapse the successors of a `cond_br` that merely forward control to
+/// another block through an unconditional branch:
+///
+///   cond_br %cond, ^bb1, ^bb2
+/// ^bb1
+///   br ^bbN(...)
+/// ^bb2
+///   br ^bbK(...)
+///
+///  -> cond_br %cond, ^bbN(...), ^bbK(...)
+///
+/// Operands of the forwarding branch that refer to the forwarding block's own
+/// arguments are remapped to the values the `cond_br` passed for them.
+struct SimplifyPassThroughCondBranch : public OpRewritePattern<CondBranchOp> {
+  using OpRewritePattern<CondBranchOp>::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(CondBranchOp condbr,
+                                PatternRewriter &rewriter) const override {
+    Block *trueDest = condbr.trueDest(), *falseDest = condbr.falseDest();
+    ValueRange trueDestOperands = condbr.getTrueOperands();
+    ValueRange falseDestOperands = condbr.getFalseOperands();
+    // Backing storage for remapped operands. collapseBranch may re-point the
+    // ValueRanges above at these vectors, so they must stay alive until the
+    // replacement op has been built.
+    SmallVector<Value, 4> trueDestOperandStorage, falseDestOperandStorage;
+
+    // Try to collapse each successor independently; both are attempted so a
+    // single rewrite can collapse the true and the false edge at once.
+    LogicalResult collapsedTrue =
+        collapseBranch(trueDest, trueDestOperands, trueDestOperandStorage);
+    LogicalResult collapsedFalse =
+        collapseBranch(falseDest, falseDestOperands, falseDestOperandStorage);
+    if (failed(collapsedTrue) && failed(collapsedFalse))
+      return failure();
+
+    // Create a new branch with the collapsed successors.
+    rewriter.replaceOpWithNewOp<CondBranchOp>(condbr, condbr.getCondition(),
+                                              trueDest, trueDestOperands,
+                                              falseDest, falseDestOperands);
+    return success();
+  }
+
+  /// Returns the unconditional branch terminating `block` if `block` contains
+  /// nothing but that branch and its arguments are used only by that branch,
+  /// i.e. the block just forwards control (and possibly values) elsewhere.
+  /// Returns a null op otherwise.
+  static BranchOp getPassThroughBranch(Block *block) {
+    // The block must hold exactly one operation: its terminator.
+    if (std::next(block->begin()) != block->end())
+      return BranchOp();
+    // That terminator must be an unconditional branch.
+    BranchOp branch = dyn_cast<BranchOp>(block->getTerminator());
+    if (!branch)
+      return BranchOp();
+    // The block arguments may only be consumed by the branch itself.
+    for (BlockArgument arg : block->getArguments())
+      for (Operation *user : arg.getUsers())
+        if (user != branch)
+          return BranchOp();
+    return branch;
+  }
+
+  /// Returns true if following the chain of pass-through blocks that starts
+  /// at `dest` leads back to `start`. Redirecting an edge from `start` to
+  /// `dest` in that case only moves it around an infinite loop of forwarding
+  /// blocks, and this pattern would keep rewriting the branch forever (a
+  /// direct self-loop, `dest == start`, is the one-block case of this).
+  static bool forwardsBackTo(Block *dest, Block *start) {
+    SmallVector<Block *, 4> visited;
+    for (Block *block = dest; block != start;) {
+      // A cycle that does not pass through `start` is not a problem here: it
+      // is rejected later, once the edge has been collapsed onto one of the
+      // blocks of that cycle.
+      if (llvm::is_contained(visited, block))
+        return false;
+      visited.push_back(block);
+      BranchOp branch = getPassThroughBranch(block);
+      if (!branch)
+        return false;
+      block = branch.getDest();
+    }
+    return true;
+  }
+
+  /// Given a successor, try to collapse it to a new destination if it only
+  /// contains a pass-through unconditional branch. If the successor is
+  /// collapsible, `successor` and `successorOperands` are updated to reference
+  /// the new destination and values. `argStorage` is used as backing storage
+  /// when operands of the collapsed successor need to be remapped; it must
+  /// outlive every use of `successorOperands`.
+  static LogicalResult collapseBranch(Block *&successor,
+                                      ValueRange &successorOperands,
+                                      SmallVectorImpl<Value> &argStorage) {
+    BranchOp successorBranch = getPassThroughBranch(successor);
+    if (!successorBranch)
+      return failure();
+    // Don't try to collapse branches to infinite loops, whether the successor
+    // branches to itself or through a longer cycle of forwarding blocks.
+    Block *successorDest = successorBranch.getDest();
+    if (forwardsBackTo(successorDest, successor))
+      return failure();
+
+    // If the successor has no arguments, the branch operands cannot refer to
+    // them, so they can be forwarded as-is.
+    OperandRange operands = successorBranch.getOperands();
+    if (successor->args_empty()) {
+      successor = successorDest;
+      successorOperands = operands;
+      return success();
+    }
+
+    // Otherwise, replace each use of one of the successor's own arguments with
+    // the value that was originally passed for that argument.
+    for (Value operand : operands) {
+      BlockArgument argOperand = operand.dyn_cast<BlockArgument>();
+      if (argOperand && argOperand.getOwner() == successor)
+        argStorage.push_back(successorOperands[argOperand.getArgNumber()]);
+      else
+        argStorage.push_back(operand);
+    }
+    successor = successorDest;
+    successorOperands = argStorage;
+    return success();
+  }
+};
+} // end anonymous namespace
void CondBranchOp::getCanonicalizationPatterns(
OwningRewritePatternList &results, MLIRContext *context) {
- results.insert<SimplifyConstCondBranchPred>(context);
+ // Simplify cond_br on a constant predicate, and collapse successors that
+ // only forward to another block (SimplifyPassThroughCondBranch).
+ results.insert<SimplifyConstCondBranchPred, SimplifyPassThroughCondBranch>(
+ context);
}
Optional<OperandRange> CondBranchOp::getSuccessorOperands(unsigned index) {
diff --git a/mlir/test/Dialect/Standard/canonicalize-cf.mlir b/mlir/test/Dialect/Standard/canonicalize-cf.mlir
new file mode 100644
index 000000000000..571c05505b48
--- /dev/null
+++ b/mlir/test/Dialect/Standard/canonicalize-cf.mlir
@@ -0,0 +1,93 @@
+// RUN: mlir-opt %s -allow-unregistered-dialect -pass-pipeline='func(canonicalize)' -split-input-file | FileCheck %s
+
+// Test the folding of BranchOp.
+
+// CHECK-LABEL: func @br_folding(
+func @br_folding() -> i32 {
+ // CHECK-NEXT: %[[CST:.*]] = constant 0 : i32
+ // CHECK-NEXT: return %[[CST]] : i32
+ // The unconditional branch only forwards a constant into ^bb1; after
+ // canonicalization the constant is expected to be returned directly from a
+ // single block.
+ %c0_i32 = constant 0 : i32
+ br ^bb1(%c0_i32 : i32)
+^bb1(%x : i32):
+ return %x : i32
+}
+
+// Test the folding of CondBranchOp with a constant condition.
+
+// CHECK-LABEL: func @cond_br_folding(
+func @cond_br_folding(%cond : i1, %a : i32) {
+ // CHECK-NEXT: cond_br %{{.*}}, ^bb1, ^bb1
+
+ // ^bb1 and ^bb2 branch on constants, so each should fold to a plain branch
+ // to ^bb3. Both then become pass-through successors and are collapsed,
+ // leaving the entry cond_br targeting the return block (renumbered ^bb1 in
+ // the output) on both edges.
+ %false_cond = constant 0 : i1
+ %true_cond = constant 1 : i1
+ cond_br %cond, ^bb1, ^bb2(%a : i32)
+
+^bb1:
+ cond_br %true_cond, ^bb3, ^bb2(%a : i32)
+
+^bb2(%x : i32):
+ cond_br %false_cond, ^bb2(%x : i32), ^bb3
+
+^bb3:
+ // CHECK: ^bb1:
+ // CHECK-NEXT: return
+
+ return
+}
+
+// Test the compound folding of BranchOp and CondBranchOp.
+
+// CHECK-LABEL: func @cond_br_and_br_folding(
+func @cond_br_and_br_folding(%a : i32) {
+ // CHECK-NEXT: return
+
+ // The entry cond_br has a constant-true condition, so it should fold to a
+ // branch to ^bb2. ^bb1 is then only reachable from itself, and the function
+ // is expected to reduce to a single return.
+ %false_cond = constant 0 : i1
+ %true_cond = constant 1 : i1
+ cond_br %true_cond, ^bb2, ^bb1(%a : i32)
+
+^bb1(%x : i32):
+ cond_br %false_cond, ^bb1(%x : i32), ^bb2
+
+^bb2:
+ return
+}
+
+// Test that pass-through successors of CondBranchOp get folded.
+
+// CHECK-LABEL: func @cond_br_pass_through(
+// CHECK-SAME: %[[ARG0:.*]]: i32, %[[ARG1:.*]]: i32, %[[ARG2:.*]]: i32
+func @cond_br_pass_through(%arg0 : i32, %arg1 : i32, %arg2 : i32, %cond : i1) -> (i32, i32) {
+ // CHECK: cond_br %{{.*}}, ^bb1(%[[ARG0]], %[[ARG1]] : i32, i32), ^bb1(%[[ARG2]], %[[ARG2]] : i32, i32)
+
+ // ^bb1 only forwards its argument (plus %arg1) to ^bb2, so the true edge is
+ // redirected to ^bb2 with %arg3 remapped to the %arg0 passed for it. ^bb2
+ // uses its arguments in the return, so the false edge is left unchanged.
+ cond_br %cond, ^bb1(%arg0 : i32), ^bb2(%arg2, %arg2 : i32, i32)
+
+^bb1(%arg3: i32):
+ br ^bb2(%arg3, %arg1 : i32, i32)
+
+^bb2(%arg4: i32, %arg5: i32):
+ // CHECK: ^bb1(%[[RET0:.*]]: i32, %[[RET1:.*]]: i32):
+ // CHECK-NEXT: return %[[RET0]], %[[RET1]]
+
+ return %arg4, %arg5 : i32, i32
+}
+
+// Test the failure modes of collapsing pass-through successors of CondBranchOp.
+
+// CHECK-LABEL: func @cond_br_pass_through_fail(
+func @cond_br_pass_through_fail(%cond : i1) {
+ // CHECK: cond_br %{{.*}}, ^bb1, ^bb2
+
+ // Neither successor qualifies as a pass-through block: ^bb1 contains an
+ // extra operation before its branch, and ^bb2 ends in a return rather than
+ // an unconditional branch.
+ cond_br %cond, ^bb1, ^bb2
+
+^bb1:
+ // CHECK: ^bb1:
+ // CHECK: "foo.op"
+ // CHECK: br ^bb2
+
+ // Successors can't be collapsed if they contain other operations.
+ "foo.op"() : () -> ()
+ br ^bb2
+
+^bb2:
+ return
+}
diff --git a/mlir/test/Transforms/canonicalize.mlir b/mlir/test/Transforms/canonicalize.mlir
index e5e3af2724eb..2524d1c7cbad 100644
--- a/mlir/test/Transforms/canonicalize.mlir
+++ b/mlir/test/Transforms/canonicalize.mlir
@@ -506,52 +506,6 @@ func @const_fold_propagate() -> memref<?x?xf32> {
return %Av : memref<?x?xf32>
}
-// CHECK-LABEL: func @br_folding
-func @br_folding() -> i32 {
- // CHECK-NEXT: %[[CST:.*]] = constant 0 : i32
- // CHECK-NEXT: return %[[CST]] : i32
- %c0_i32 = constant 0 : i32
- br ^bb1(%c0_i32 : i32)
-^bb1(%x : i32):
- return %x : i32
-}
-
-// CHECK-LABEL: func @cond_br_folding
-func @cond_br_folding(%cond : i1, %a : i32) {
- %false_cond = constant 0 : i1
- %true_cond = constant 1 : i1
- cond_br %cond, ^bb1, ^bb2(%a : i32)
-
-^bb1:
- // CHECK: ^bb1:
- // CHECK-NEXT: br ^bb3
- cond_br %true_cond, ^bb3, ^bb2(%a : i32)
-
-^bb2(%x : i32):
- // CHECK: ^bb2
- // CHECK: br ^bb3
- cond_br %false_cond, ^bb2(%x : i32), ^bb3
-
-^bb3:
- return
-}
-
-// CHECK-LABEL: func @cond_br_and_br_folding
-func @cond_br_and_br_folding(%a : i32) {
- // Test the compound folding of conditional and unconditional branches.
- // CHECK-NEXT: return
-
- %false_cond = constant 0 : i1
- %true_cond = constant 1 : i1
- cond_br %true_cond, ^bb2, ^bb1(%a : i32)
-
-^bb1(%x : i32):
- cond_br %false_cond, ^bb1(%x : i32), ^bb2
-
-^bb2:
- return
-}
-
// CHECK-LABEL: func @indirect_call_folding
func @indirect_target() {
return
More information about the Mlir-commits
mailing list