[Mlir-commits] [mlir] 2f4b303 - [mlir][Standard] Add canonicalization for collapsing pass through cond_br successors.

River Riddle llvmlistbot at llvm.org
Thu Apr 23 04:51:02 PDT 2020


Author: River Riddle
Date: 2020-04-23T04:42:01-07:00
New Revision: 2f4b303d683c91bb5b3799688ecc689f59ed9de5

URL: https://github.com/llvm/llvm-project/commit/2f4b303d683c91bb5b3799688ecc689f59ed9de5
DIFF: https://github.com/llvm/llvm-project/commit/2f4b303d683c91bb5b3799688ecc689f59ed9de5.diff

LOG: [mlir][Standard] Add canonicalization for collapsing pass through cond_br successors.

This revision adds support for the following canonicalization:

```
   cond_br %cond, ^bb1, ^bb2
 ^bb1
   br ^bbN(...)
 ^bb2
   br ^bbK(...)

-> cond_br %cond, ^bbN(...), ^bbK(...)
```

Differential Revision: https://reviews.llvm.org/D78681

Added: 
    mlir/test/Dialect/Standard/canonicalize-cf.mlir

Modified: 
    mlir/lib/Dialect/StandardOps/IR/Ops.cpp
    mlir/test/Transforms/canonicalize.mlir

Removed: 
    


################################################################################
diff  --git a/mlir/lib/Dialect/StandardOps/IR/Ops.cpp b/mlir/lib/Dialect/StandardOps/IR/Ops.cpp
index 0a96c9a5bf44..75bec0800628 100644
--- a/mlir/lib/Dialect/StandardOps/IR/Ops.cpp
+++ b/mlir/lib/Dialect/StandardOps/IR/Ops.cpp
@@ -862,11 +862,93 @@ struct SimplifyConstCondBranchPred : public OpRewritePattern<CondBranchOp> {
     return failure();
   }
 };
-} // end anonymous namespace.
+
+///   cond_br %cond, ^bb1, ^bb2
+/// ^bb1
+///   br ^bbN(...)
+/// ^bb2
+///   br ^bbK(...)
+///
+/// -> cond_br %cond, ^bbN(...), ^bbK(...)
+///
+struct SimplifyPassThroughCondBranch : public OpRewritePattern<CondBranchOp> {
+  using OpRewritePattern<CondBranchOp>::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(CondBranchOp condbr,
+                                PatternRewriter &rewriter) const override {
+    Block *trueDest = condbr.trueDest(), *falseDest = condbr.falseDest();
+    ValueRange trueDestOperands = condbr.getTrueOperands();
+    ValueRange falseDestOperands = condbr.getFalseOperands();
+    SmallVector<Value, 4> trueDestOperandStorage, falseDestOperandStorage;
+
+    // Try to collapse one of the current successors.
+    LogicalResult collapsedTrue =
+        collapseBranch(trueDest, trueDestOperands, trueDestOperandStorage);
+    LogicalResult collapsedFalse =
+        collapseBranch(falseDest, falseDestOperands, falseDestOperandStorage);
+    if (failed(collapsedTrue) && failed(collapsedFalse))
+      return failure();
+
+    // Create a new branch with the collapsed successors.
+    rewriter.replaceOpWithNewOp<CondBranchOp>(condbr, condbr.getCondition(),
+                                              trueDest, trueDestOperands,
+                                              falseDest, falseDestOperands);
+    return success();
+  }
+
+  /// Given a successor, try to collapse it to a new destination if it only
+  /// contains a passthrough unconditional branch. If the successor is
+  /// collapsable, `successor` and `successorOperands` are updated to reference
+  /// the new destination and values. `argStorage` is an optional storage to use
+  /// if operands to the collapsed successor need to be remapped.
+  LogicalResult collapseBranch(Block *&successor, ValueRange &successorOperands,
+                               SmallVectorImpl<Value> &argStorage) const {
+    // Check that the successor only contains an unconditional branch.
+    if (std::next(successor->begin()) != successor->end())
+      return failure();
+    // Check that the terminator is an unconditional branch.
+    BranchOp successorBranch = dyn_cast<BranchOp>(successor->getTerminator());
+    if (!successorBranch)
+      return failure();
+    // Check that the arguments are only used within the terminator.
+    for (BlockArgument arg : successor->getArguments()) {
+      for (Operation *user : arg.getUsers())
+        if (user != successorBranch)
+          return failure();
+    }
+    // Don't try to collapse branches to infinite loops.
+    Block *successorDest = successorBranch.getDest();
+    if (successorDest == successor)
+      return failure();
+
+    // Update the operands to the successor. If the branch parent has no
+    // arguments, we can use the branch operands directly.
+    OperandRange operands = successorBranch.getOperands();
+    if (successor->args_empty()) {
+      successor = successorBranch.getDest();
+      successorOperands = operands;
+      return success();
+    }
+
+    // Otherwise, we need to remap any argument operands.
+    for (Value operand : operands) {
+      BlockArgument argOperand = operand.dyn_cast<BlockArgument>();
+      if (argOperand && argOperand.getOwner() == successor)
+        argStorage.push_back(successorOperands[argOperand.getArgNumber()]);
+      else
+        argStorage.push_back(operand);
+    }
+    successor = successorBranch.getDest();
+    successorOperands = argStorage;
+    return success();
+  }
+};
+} // end anonymous namespace
 
 void CondBranchOp::getCanonicalizationPatterns(
     OwningRewritePatternList &results, MLIRContext *context) {
-  results.insert<SimplifyConstCondBranchPred>(context);
+  results.insert<SimplifyConstCondBranchPred, SimplifyPassThroughCondBranch>(
+      context);
 }
 
 Optional<OperandRange> CondBranchOp::getSuccessorOperands(unsigned index) {

diff  --git a/mlir/test/Dialect/Standard/canonicalize-cf.mlir b/mlir/test/Dialect/Standard/canonicalize-cf.mlir
new file mode 100644
index 000000000000..571c05505b48
--- /dev/null
+++ b/mlir/test/Dialect/Standard/canonicalize-cf.mlir
@@ -0,0 +1,93 @@
+// RUN: mlir-opt %s -allow-unregistered-dialect -pass-pipeline='func(canonicalize)' -split-input-file | FileCheck %s
+
+// Test the folding of BranchOp.
+
+// CHECK-LABEL: func @br_folding(
+func @br_folding() -> i32 {
+  // CHECK-NEXT: %[[CST:.*]] = constant 0 : i32
+  // CHECK-NEXT: return %[[CST]] : i32
+  %c0_i32 = constant 0 : i32
+  br ^bb1(%c0_i32 : i32)
+^bb1(%x : i32):
+  return %x : i32
+}
+
+// Test the folding of CondBranchOp with a constant condition.
+
+// CHECK-LABEL: func @cond_br_folding(
+func @cond_br_folding(%cond : i1, %a : i32) {
+  // CHECK-NEXT: cond_br %{{.*}}, ^bb1, ^bb1
+
+  %false_cond = constant 0 : i1
+  %true_cond = constant 1 : i1
+  cond_br %cond, ^bb1, ^bb2(%a : i32)
+
+^bb1:
+  cond_br %true_cond, ^bb3, ^bb2(%a : i32)
+
+^bb2(%x : i32):
+  cond_br %false_cond, ^bb2(%x : i32), ^bb3
+
+^bb3:
+  // CHECK: ^bb1:
+  // CHECK-NEXT: return
+
+  return
+}
+
+// Test the compound folding of BranchOp and CondBranchOp.
+
+// CHECK-LABEL: func @cond_br_and_br_folding(
+func @cond_br_and_br_folding(%a : i32) {
+  // CHECK-NEXT: return
+
+  %false_cond = constant 0 : i1
+  %true_cond = constant 1 : i1
+  cond_br %true_cond, ^bb2, ^bb1(%a : i32)
+
+^bb1(%x : i32):
+  cond_br %false_cond, ^bb1(%x : i32), ^bb2
+
+^bb2:
+  return
+}
+
+/// Test that pass-through successors of CondBranchOp get folded.
+
+// CHECK-LABEL: func @cond_br_pass_through(
+// CHECK-SAME: %[[ARG0:.*]]: i32, %[[ARG1:.*]]: i32, %[[ARG2:.*]]: i32
+func @cond_br_pass_through(%arg0 : i32, %arg1 : i32, %arg2 : i32, %cond : i1) -> (i32, i32) {
+  // CHECK: cond_br %{{.*}}, ^bb1(%[[ARG0]], %[[ARG1]] : i32, i32), ^bb1(%[[ARG2]], %[[ARG2]] : i32, i32)
+
+  cond_br %cond, ^bb1(%arg0 : i32), ^bb2(%arg2, %arg2 : i32, i32)
+
+^bb1(%arg3: i32):
+  br ^bb2(%arg3, %arg1 : i32, i32)
+
+^bb2(%arg4: i32, %arg5: i32):
+  // CHECK: ^bb1(%[[RET0:.*]]: i32, %[[RET1:.*]]: i32):
+  // CHECK-NEXT: return %[[RET0]], %[[RET1]]
+
+  return %arg4, %arg5 : i32, i32
+}
+
+/// Test the failure modes of collapsing CondBranchOp pass-through successors.
+
+// CHECK-LABEL: func @cond_br_pass_through_fail(
+func @cond_br_pass_through_fail(%cond : i1) {
+  // CHECK: cond_br %{{.*}}, ^bb1, ^bb2
+
+  cond_br %cond, ^bb1, ^bb2
+
+^bb1:
+  // CHECK: ^bb1:
+  // CHECK: "foo.op"
+  // CHECK: br ^bb2
+
+  // Successors can't be collapsed if they contain other operations.
+  "foo.op"() : () -> ()
+  br ^bb2
+
+^bb2:
+  return
+}

diff  --git a/mlir/test/Transforms/canonicalize.mlir b/mlir/test/Transforms/canonicalize.mlir
index e5e3af2724eb..2524d1c7cbad 100644
--- a/mlir/test/Transforms/canonicalize.mlir
+++ b/mlir/test/Transforms/canonicalize.mlir
@@ -506,52 +506,6 @@ func @const_fold_propagate() -> memref<?x?xf32> {
   return %Av : memref<?x?xf32>
 }
 
-// CHECK-LABEL: func @br_folding
-func @br_folding() -> i32 {
-  // CHECK-NEXT: %[[CST:.*]] = constant 0 : i32
-  // CHECK-NEXT: return %[[CST]] : i32
-  %c0_i32 = constant 0 : i32
-  br ^bb1(%c0_i32 : i32)
-^bb1(%x : i32):
-  return %x : i32
-}
-
-// CHECK-LABEL: func @cond_br_folding
-func @cond_br_folding(%cond : i1, %a : i32) {
-  %false_cond = constant 0 : i1
-  %true_cond = constant 1 : i1
-  cond_br %cond, ^bb1, ^bb2(%a : i32)
-
-^bb1:
-  // CHECK: ^bb1:
-  // CHECK-NEXT: br ^bb3
-  cond_br %true_cond, ^bb3, ^bb2(%a : i32)
-
-^bb2(%x : i32):
-  // CHECK: ^bb2
-  // CHECK: br ^bb3
-  cond_br %false_cond, ^bb2(%x : i32), ^bb3
-
-^bb3:
-  return
-}
-
-// CHECK-LABEL: func @cond_br_and_br_folding
-func @cond_br_and_br_folding(%a : i32) {
-  // Test the compound folding of conditional and unconditional branches.
-  // CHECK-NEXT: return
-
-  %false_cond = constant 0 : i1
-  %true_cond = constant 1 : i1
-  cond_br %true_cond, ^bb2, ^bb1(%a : i32)
-
-^bb1(%x : i32):
-  cond_br %false_cond, ^bb1(%x : i32), ^bb2
-
-^bb2:
-  return
-}
-
 // CHECK-LABEL: func @indirect_call_folding
 func @indirect_target() {
   return


        


More information about the Mlir-commits mailing list