[Mlir-commits] [mlir] af331bc - [mlir][Standard] Add a canonicalization to simplify cond_br when the successors are identical

River Riddle llvmlistbot at llvm.org
Thu Apr 23 04:51:05 PDT 2020


Author: River Riddle
Date: 2020-04-23T04:42:02-07:00
New Revision: af331bc52dc1a7c2bbfd09ddba1df4c7f3d321e7

URL: https://github.com/llvm/llvm-project/commit/af331bc52dc1a7c2bbfd09ddba1df4c7f3d321e7
DIFF: https://github.com/llvm/llvm-project/commit/af331bc52dc1a7c2bbfd09ddba1df4c7f3d321e7.diff

LOG: [mlir][Standard] Add a canonicalization to simplify cond_br when the successors are identical

This revision adds support for canonicalizing the following:

```
cond_br %cond, ^bb1(A, ..., N), ^bb1(A, ..., N)
  -> br ^bb1(A, ..., N)
```

If the operands to the successor are different and the block containing the cond_br is the successor's only predecessor, we emit selects for the branch operands.

```
cond_br %cond, ^bb1(A), ^bb1(B)
  -> %select = select %cond, A, B
     br ^bb1(%select)
```

Differential Revision: https://reviews.llvm.org/D78682

Added: 
    

Modified: 
    llvm/include/llvm/ADT/STLExtras.h
    mlir/include/mlir/Dialect/StandardOps/IR/Ops.td
    mlir/include/mlir/IR/Block.h
    mlir/lib/Dialect/StandardOps/IR/Ops.cpp
    mlir/lib/IR/Block.cpp
    mlir/test/Dialect/Standard/canonicalize-cf.mlir
    mlir/test/IR/invalid-ops.mlir

Removed: 
    


################################################################################
diff  --git a/llvm/include/llvm/ADT/STLExtras.h b/llvm/include/llvm/ADT/STLExtras.h
index e3b9e405f69a..c10f4d233e50 100644
--- a/llvm/include/llvm/ADT/STLExtras.h
+++ b/llvm/include/llvm/ADT/STLExtras.h
@@ -1123,10 +1123,13 @@ class indexed_accessor_range_base {
   }
 
   /// Compare this range with another.
-  template <typename OtherT> bool operator==(const OtherT &other) {
+  template <typename OtherT> bool operator==(const OtherT &other) const {
     return size() == std::distance(other.begin(), other.end()) &&
            std::equal(begin(), end(), other.begin());
   }
+  template <typename OtherT> bool operator!=(const OtherT &other) const {
+    return !(*this == other);
+  }
 
   /// Return the size of this range.
   size_t size() const { return count; }

diff  --git a/mlir/include/mlir/Dialect/StandardOps/IR/Ops.td b/mlir/include/mlir/Dialect/StandardOps/IR/Ops.td
index bf0769dac652..39c9597d866a 100644
--- a/mlir/include/mlir/Dialect/StandardOps/IR/Ops.td
+++ b/mlir/include/mlir/Dialect/StandardOps/IR/Ops.td
@@ -1951,9 +1951,9 @@ def SelectOp : Std_Op<"select", [NoSideEffect, SameOperandsAndResultShape,
   }];
 
   let arguments = (ins BoolLike:$condition,
-                       SignlessIntegerOrFloatLike:$true_value,
-                       SignlessIntegerOrFloatLike:$false_value);
-  let results = (outs SignlessIntegerOrFloatLike:$result);
+                       AnyType:$true_value,
+                       AnyType:$false_value);
+  let results = (outs AnyType:$result);
   let verifier = ?;
 
   let builders = [OpBuilder<

diff  --git a/mlir/include/mlir/IR/Block.h b/mlir/include/mlir/IR/Block.h
index 9818ff698092..12f82f84b52a 100644
--- a/mlir/include/mlir/IR/Block.h
+++ b/mlir/include/mlir/IR/Block.h
@@ -248,6 +248,10 @@ class Block : public IRObjectWithUseList<BlockOperand>,
   /// destinations) is not considered to be a single predecessor.
   Block *getSinglePredecessor();
 
+  /// If this block has a unique predecessor, i.e., all incoming edges originate
+  /// from one block, return it. Otherwise, return null.
+  Block *getUniquePredecessor();
+
   // Indexed successor access.
   unsigned getNumSuccessors();
   Block *getSuccessor(unsigned i);

diff  --git a/mlir/lib/Dialect/StandardOps/IR/Ops.cpp b/mlir/lib/Dialect/StandardOps/IR/Ops.cpp
index 75bec0800628..f1fb0f90b57a 100644
--- a/mlir/lib/Dialect/StandardOps/IR/Ops.cpp
+++ b/mlir/lib/Dialect/StandardOps/IR/Ops.cpp
@@ -684,23 +684,15 @@ void CallIndirectOp::getCanonicalizationPatterns(
 //===----------------------------------------------------------------------===//
 
 // Return the type of the same shape (scalar, vector or tensor) containing i1.
-static Type getCheckedI1SameShape(Type type) {
+static Type getI1SameShape(Type type) {
   auto i1Type = IntegerType::get(1, type.getContext());
-  if (type.isSignlessIntOrIndexOrFloat())
-    return i1Type;
   if (auto tensorType = type.dyn_cast<RankedTensorType>())
     return RankedTensorType::get(tensorType.getShape(), i1Type);
   if (type.isa<UnrankedTensorType>())
     return UnrankedTensorType::get(i1Type);
   if (auto vectorType = type.dyn_cast<VectorType>())
     return VectorType::get(vectorType.getShape(), i1Type);
-  return Type();
-}
-
-static Type getI1SameShape(Type type) {
-  Type res = getCheckedI1SameShape(type);
-  assert(res && "expected type with valid i1 shape");
-  return res;
+  return i1Type;
 }
 
 //===----------------------------------------------------------------------===//
@@ -840,8 +832,10 @@ OpFoldResult CmpFOp::fold(ArrayRef<Attribute> operands) {
 //===----------------------------------------------------------------------===//
 
 namespace {
-/// cond_br true, ^bb1, ^bb2 -> br ^bb1
-/// cond_br false, ^bb1, ^bb2 -> br ^bb2
+/// cond_br true, ^bb1, ^bb2
+///  -> br ^bb1
+/// cond_br false, ^bb1, ^bb2
+///  -> br ^bb2
 ///
 struct SimplifyConstCondBranchPred : public OpRewritePattern<CondBranchOp> {
   using OpRewritePattern<CondBranchOp>::OpRewritePattern;
@@ -869,7 +863,7 @@ struct SimplifyConstCondBranchPred : public OpRewritePattern<CondBranchOp> {
 /// ^bb2
 ///   br ^bbK(...)
 ///
-///   cond_br %cond, ^bbN(...), ^bbK(...)
+///  -> cond_br %cond, ^bbN(...), ^bbK(...)
 ///
 struct SimplifyPassThroughCondBranch : public OpRewritePattern<CondBranchOp> {
   using OpRewritePattern<CondBranchOp>::OpRewritePattern;
@@ -943,12 +937,70 @@ struct SimplifyPassThroughCondBranch : public OpRewritePattern<CondBranchOp> {
     return success();
   }
 };
+
+/// cond_br %cond, ^bb1(A, ..., N), ^bb1(A, ..., N)
+///  -> br ^bb1(A, ..., N)
+///
+/// cond_br %cond, ^bb1(A), ^bb1(B)
+///  -> %select = select %cond, A, B
+///     br ^bb1(%select)
+///
+struct SimplifyCondBranchIdenticalSuccessors
+    : public OpRewritePattern<CondBranchOp> {
+  using OpRewritePattern<CondBranchOp>::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(CondBranchOp condbr,
+                                PatternRewriter &rewriter) const override {
+    // Check that the true and false destinations are the same and have the same
+    // operands.
+    Block *trueDest = condbr.trueDest();
+    if (trueDest != condbr.falseDest())
+      return failure();
+
+    // If all of the operands match, no selects need to be generated.
+    OperandRange trueOperands = condbr.getTrueOperands();
+    OperandRange falseOperands = condbr.getFalseOperands();
+    if (trueOperands == falseOperands) {
+      rewriter.replaceOpWithNewOp<BranchOp>(condbr, trueDest, trueOperands);
+      return success();
+    }
+
+    // Otherwise, if the current block is the only predecessor insert selects
+    // for any mismatched branch operands.
+    if (trueDest->getUniquePredecessor() != condbr.getOperation()->getBlock())
+      return failure();
+
+    // TODO: ATM Tensor/Vector SelectOp requires that the condition has the same
+    // shape as the operands. We should relax that to allow an i1 to signify
+    // that everything is selected.
+    auto doesntSupportsScalarI1 = [](Type type) {
+      return type.isa<TensorType>() || type.isa<VectorType>();
+    };
+    if (llvm::any_of(trueOperands.getTypes(), doesntSupportsScalarI1))
+      return failure();
+
+    // Generate a select for any operands that differ between the two.
+    SmallVector<Value, 8> mergedOperands;
+    mergedOperands.reserve(trueOperands.size());
+    Value condition = condbr.getCondition();
+    for (auto it : llvm::zip(trueOperands, falseOperands)) {
+      if (std::get<0>(it) == std::get<1>(it))
+        mergedOperands.push_back(std::get<0>(it));
+      else
+        mergedOperands.push_back(rewriter.create<SelectOp>(
+            condbr.getLoc(), condition, std::get<0>(it), std::get<1>(it)));
+    }
+
+    rewriter.replaceOpWithNewOp<BranchOp>(condbr, trueDest, mergedOperands);
+    return success();
+  }
+};
 } // end anonymous namespace
 
 void CondBranchOp::getCanonicalizationPatterns(
     OwningRewritePatternList &results, MLIRContext *context) {
-  results.insert<SimplifyConstCondBranchPred, SimplifyPassThroughCondBranch>(
-      context);
+  results.insert<SimplifyConstCondBranchPred, SimplifyPassThroughCondBranch,
+                 SimplifyCondBranchIdenticalSuccessors>(context);
 }
 
 Optional<OperandRange> CondBranchOp::getSuccessorOperands(unsigned index) {

diff  --git a/mlir/lib/IR/Block.cpp b/mlir/lib/IR/Block.cpp
index 392af93a3530..80131325adff 100644
--- a/mlir/lib/IR/Block.cpp
+++ b/mlir/lib/IR/Block.cpp
@@ -229,6 +229,21 @@ Block *Block::getSinglePredecessor() {
   return it == pred_end() ? firstPred : nullptr;
 }
 
+/// If this block has a unique predecessor, i.e., all incoming edges originate
+/// from one block, return it. Otherwise, return null.
+Block *Block::getUniquePredecessor() {
+  auto it = pred_begin(), e = pred_end();
+  if (it == e)
+    return nullptr;
+
+  // Check for any conflicting predecessors.
+  auto *firstPred = *it;
+  for (++it; it != e; ++it)
+    if (*it != firstPred)
+      return nullptr;
+  return firstPred;
+}
+
 //===----------------------------------------------------------------------===//
 // Other
 //===----------------------------------------------------------------------===//

diff  --git a/mlir/test/Dialect/Standard/canonicalize-cf.mlir b/mlir/test/Dialect/Standard/canonicalize-cf.mlir
index 571c05505b48..8b7b3020fae0 100644
--- a/mlir/test/Dialect/Standard/canonicalize-cf.mlir
+++ b/mlir/test/Dialect/Standard/canonicalize-cf.mlir
@@ -1,6 +1,6 @@
 // RUN: mlir-opt %s -allow-unregistered-dialect -pass-pipeline='func(canonicalize)' -split-input-file | FileCheck %s
 
-// Test the folding of BranchOp.
+/// Test the folding of BranchOp.
 
 // CHECK-LABEL: func @br_folding(
 func @br_folding() -> i32 {
@@ -12,11 +12,11 @@ func @br_folding() -> i32 {
   return %x : i32
 }
 
-// Test the folding of CondBranchOp with a constant condition.
+/// Test the folding of CondBranchOp with a constant condition.
 
 // CHECK-LABEL: func @cond_br_folding(
 func @cond_br_folding(%cond : i1, %a : i32) {
-  // CHECK-NEXT: cond_br %{{.*}}, ^bb1, ^bb1
+  // CHECK-NEXT: return
 
   %false_cond = constant 0 : i1
   %true_cond = constant 1 : i1
@@ -29,13 +29,62 @@ func @cond_br_folding(%cond : i1, %a : i32) {
   cond_br %false_cond, ^bb2(%x : i32), ^bb3
 
 ^bb3:
-  // CHECK: ^bb1:
+  return
+}
+
+/// Test the folding of CondBranchOp when the successors are identical.
+
+// CHECK-LABEL: func @cond_br_same_successor(
+func @cond_br_same_successor(%cond : i1, %a : i32) {
   // CHECK-NEXT: return
 
+  cond_br %cond, ^bb1(%a : i32), ^bb1(%a : i32)
+
+^bb1(%result : i32):
   return
 }
 
-// Test the compound folding of BranchOp and CondBranchOp.
+/// Test the folding of CondBranchOp when the successors are identical, but the
+/// arguments are different.
+
+// CHECK-LABEL: func @cond_br_same_successor_insert_select(
+// CHECK-SAME: %[[COND:.*]]: i1, %[[ARG0:.*]]: i32, %[[ARG1:.*]]: i32
+func @cond_br_same_successor_insert_select(%cond : i1, %a : i32, %b : i32) -> i32 {
+  // CHECK: %[[RES:.*]] = select %[[COND]], %[[ARG0]], %[[ARG1]]
+  // CHECK: return %[[RES]]
+
+  cond_br %cond, ^bb1(%a : i32), ^bb1(%b : i32)
+
+^bb1(%result : i32):
+  return %result : i32
+}
+
+/// Check that we don't generate a select if the type requires a splat.
+/// TODO: SelectOp should allow for matching a vector/tensor with i1.
+
+// CHECK-LABEL: func @cond_br_same_successor_no_select_tensor(
+func @cond_br_same_successor_no_select_tensor(%cond : i1, %a : tensor<2xi32>,
+                                              %b : tensor<2xi32>) -> tensor<2xi32>{
+  // CHECK: cond_br
+
+  cond_br %cond, ^bb1(%a : tensor<2xi32>), ^bb1(%b : tensor<2xi32>)
+
+^bb1(%result : tensor<2xi32>):
+  return %result : tensor<2xi32>
+}
+
+// CHECK-LABEL: func @cond_br_same_successor_no_select_vector(
+func @cond_br_same_successor_no_select_vector(%cond : i1, %a : vector<2xi32>,
+                                              %b : vector<2xi32>) -> vector<2xi32> {
+  // CHECK: cond_br
+
+  cond_br %cond, ^bb1(%a : vector<2xi32>), ^bb1(%b : vector<2xi32>)
+
+^bb1(%result : vector<2xi32>):
+  return %result : vector<2xi32>
+}
+
+/// Test the compound folding of BranchOp and CondBranchOp.
 
 // CHECK-LABEL: func @cond_br_and_br_folding(
 func @cond_br_and_br_folding(%a : i32) {
@@ -55,9 +104,11 @@ func @cond_br_and_br_folding(%a : i32) {
 /// Test that pass-through successors of CondBranchOp get folded.
 
 // CHECK-LABEL: func @cond_br_pass_through(
-// CHECK-SAME: %[[ARG0:.*]]: i32, %[[ARG1:.*]]: i32, %[[ARG2:.*]]: i32
+// CHECK-SAME: %[[ARG0:.*]]: i32, %[[ARG1:.*]]: i32, %[[ARG2:.*]]: i32, %[[COND:.*]]: i1
 func @cond_br_pass_through(%arg0 : i32, %arg1 : i32, %arg2 : i32, %cond : i1) -> (i32, i32) {
-  // CHECK: cond_br %{{.*}}, ^bb1(%[[ARG0]], %[[ARG1]] : i32, i32), ^bb1(%[[ARG2]], %[[ARG2]] : i32, i32)
+  // CHECK: %[[RES:.*]] = select %[[COND]], %[[ARG0]], %[[ARG2]]
+  // CHECK: %[[RES2:.*]] = select %[[COND]], %[[ARG1]], %[[ARG2]]
+  // CHECK: return %[[RES]], %[[RES2]]
 
   cond_br %cond, ^bb1(%arg0 : i32), ^bb2(%arg2, %arg2 : i32, i32)
 
@@ -65,9 +116,6 @@ func @cond_br_pass_through(%arg0 : i32, %arg1 : i32, %arg2 : i32, %cond : i1) ->
   br ^bb2(%arg3, %arg1 : i32, i32)
 
 ^bb2(%arg4: i32, %arg5: i32):
-  // CHECK: ^bb1(%[[RET0:.*]]: i32, %[[RET1:.*]]: i32):
-  // CHECK-NEXT: return %[[RET0]], %[[RET1]]
-
   return %arg4, %arg5 : i32, i32
 }
 

diff  --git a/mlir/test/IR/invalid-ops.mlir b/mlir/test/IR/invalid-ops.mlir
index fe2556d0a9d9..c7b290517f02 100644
--- a/mlir/test/IR/invalid-ops.mlir
+++ b/mlir/test/IR/invalid-ops.mlir
@@ -297,12 +297,6 @@ func @func_with_ops(i1, tensor<42xi32>, tensor<?xi32>) {
 
 // -----
 
-func @invalid_select_shape(%cond : i1, %idx : () -> ()) {
-  // expected-error @+1 {{'result' must be signless-integer-like or floating-point-like, but got '() -> ()'}}
-  %sel = select %cond, %idx, %idx : () -> ()
-
-// -----
-
 func @invalid_cmp_shape(%idx : () -> ()) {
   // expected-error @+1 {{'lhs' must be signless-integer-like, but got '() -> ()'}}
   %cmp = cmpi "eq", %idx, %idx : () -> ()


        


More information about the Mlir-commits mailing list