[Mlir-commits] [mlir] af331bc - [mlir][Standard] Add a canonicalization to simplify cond_br when the successors are identical
River Riddle
llvmlistbot at llvm.org
Thu Apr 23 04:51:05 PDT 2020
Author: River Riddle
Date: 2020-04-23T04:42:02-07:00
New Revision: af331bc52dc1a7c2bbfd09ddba1df4c7f3d321e7
URL: https://github.com/llvm/llvm-project/commit/af331bc52dc1a7c2bbfd09ddba1df4c7f3d321e7
DIFF: https://github.com/llvm/llvm-project/commit/af331bc52dc1a7c2bbfd09ddba1df4c7f3d321e7.diff
LOG: [mlir][Standard] Add a canonicalization to simplify cond_br when the successors are identical
This revision adds support for canonicalizing the following:
```
cond_br %cond, ^bb1(A, ..., N), ^bb1(A, ..., N)
br ^bb1(A, ..., N)
```
If the operands to the successor are different and the cond_br is the only predecessor, we emit selects for the branch operands.
```
cond_br %cond, ^bb1(A), ^bb1(B)
%select = select %cond, A, B
br ^bb1(%select)
```
Differential Revision: https://reviews.llvm.org/D78682
Added:
Modified:
llvm/include/llvm/ADT/STLExtras.h
mlir/include/mlir/Dialect/StandardOps/IR/Ops.td
mlir/include/mlir/IR/Block.h
mlir/lib/Dialect/StandardOps/IR/Ops.cpp
mlir/lib/IR/Block.cpp
mlir/test/Dialect/Standard/canonicalize-cf.mlir
mlir/test/IR/invalid-ops.mlir
Removed:
################################################################################
diff --git a/llvm/include/llvm/ADT/STLExtras.h b/llvm/include/llvm/ADT/STLExtras.h
index e3b9e405f69a..c10f4d233e50 100644
--- a/llvm/include/llvm/ADT/STLExtras.h
+++ b/llvm/include/llvm/ADT/STLExtras.h
@@ -1123,10 +1123,13 @@ class indexed_accessor_range_base {
}
/// Compare this range with another.
- template <typename OtherT> bool operator==(const OtherT &other) {
+ template <typename OtherT> bool operator==(const OtherT &other) const {
return size() == std::distance(other.begin(), other.end()) &&
std::equal(begin(), end(), other.begin());
}
+ template <typename OtherT> bool operator!=(const OtherT &other) const {
+ return !(*this == other);
+ }
/// Return the size of this range.
size_t size() const { return count; }
diff --git a/mlir/include/mlir/Dialect/StandardOps/IR/Ops.td b/mlir/include/mlir/Dialect/StandardOps/IR/Ops.td
index bf0769dac652..39c9597d866a 100644
--- a/mlir/include/mlir/Dialect/StandardOps/IR/Ops.td
+++ b/mlir/include/mlir/Dialect/StandardOps/IR/Ops.td
@@ -1951,9 +1951,9 @@ def SelectOp : Std_Op<"select", [NoSideEffect, SameOperandsAndResultShape,
}];
let arguments = (ins BoolLike:$condition,
- SignlessIntegerOrFloatLike:$true_value,
- SignlessIntegerOrFloatLike:$false_value);
- let results = (outs SignlessIntegerOrFloatLike:$result);
+ AnyType:$true_value,
+ AnyType:$false_value);
+ let results = (outs AnyType:$result);
let verifier = ?;
let builders = [OpBuilder<
diff --git a/mlir/include/mlir/IR/Block.h b/mlir/include/mlir/IR/Block.h
index 9818ff698092..12f82f84b52a 100644
--- a/mlir/include/mlir/IR/Block.h
+++ b/mlir/include/mlir/IR/Block.h
@@ -248,6 +248,10 @@ class Block : public IRObjectWithUseList<BlockOperand>,
/// destinations) is not considered to be a single predecessor.
Block *getSinglePredecessor();
+ /// If this block has a unique predecessor, i.e., all incoming edges originate
+ /// from one block, return it. Otherwise, return null.
+ Block *getUniquePredecessor();
+
// Indexed successor access.
unsigned getNumSuccessors();
Block *getSuccessor(unsigned i);
diff --git a/mlir/lib/Dialect/StandardOps/IR/Ops.cpp b/mlir/lib/Dialect/StandardOps/IR/Ops.cpp
index 75bec0800628..f1fb0f90b57a 100644
--- a/mlir/lib/Dialect/StandardOps/IR/Ops.cpp
+++ b/mlir/lib/Dialect/StandardOps/IR/Ops.cpp
@@ -684,23 +684,15 @@ void CallIndirectOp::getCanonicalizationPatterns(
//===----------------------------------------------------------------------===//
// Return the type of the same shape (scalar, vector or tensor) containing i1.
-static Type getCheckedI1SameShape(Type type) {
+static Type getI1SameShape(Type type) {
auto i1Type = IntegerType::get(1, type.getContext());
- if (type.isSignlessIntOrIndexOrFloat())
- return i1Type;
if (auto tensorType = type.dyn_cast<RankedTensorType>())
return RankedTensorType::get(tensorType.getShape(), i1Type);
if (type.isa<UnrankedTensorType>())
return UnrankedTensorType::get(i1Type);
if (auto vectorType = type.dyn_cast<VectorType>())
return VectorType::get(vectorType.getShape(), i1Type);
- return Type();
-}
-
-static Type getI1SameShape(Type type) {
- Type res = getCheckedI1SameShape(type);
- assert(res && "expected type with valid i1 shape");
- return res;
+ return i1Type;
}
//===----------------------------------------------------------------------===//
@@ -840,8 +832,10 @@ OpFoldResult CmpFOp::fold(ArrayRef<Attribute> operands) {
//===----------------------------------------------------------------------===//
namespace {
-/// cond_br true, ^bb1, ^bb2 -> br ^bb1
-/// cond_br false, ^bb1, ^bb2 -> br ^bb2
+/// cond_br true, ^bb1, ^bb2
+/// -> br ^bb1
+/// cond_br false, ^bb1, ^bb2
+/// -> br ^bb2
///
struct SimplifyConstCondBranchPred : public OpRewritePattern<CondBranchOp> {
using OpRewritePattern<CondBranchOp>::OpRewritePattern;
@@ -869,7 +863,7 @@ struct SimplifyConstCondBranchPred : public OpRewritePattern<CondBranchOp> {
/// ^bb2
/// br ^bbK(...)
///
-/// cond_br %cond, ^bbN(...), ^bbK(...)
+/// -> cond_br %cond, ^bbN(...), ^bbK(...)
///
struct SimplifyPassThroughCondBranch : public OpRewritePattern<CondBranchOp> {
using OpRewritePattern<CondBranchOp>::OpRewritePattern;
@@ -943,12 +937,70 @@ struct SimplifyPassThroughCondBranch : public OpRewritePattern<CondBranchOp> {
return success();
}
};
+
+/// cond_br %cond, ^bb1(A, ..., N), ^bb1(A, ..., N)
+/// -> br ^bb1(A, ..., N)
+///
+/// cond_br %cond, ^bb1(A), ^bb1(B)
+/// -> %select = select %cond, A, B
+/// br ^bb1(%select)
+///
+struct SimplifyCondBranchIdenticalSuccessors
+ : public OpRewritePattern<CondBranchOp> {
+ using OpRewritePattern<CondBranchOp>::OpRewritePattern;
+
+ LogicalResult matchAndRewrite(CondBranchOp condbr,
+ PatternRewriter &rewriter) const override {
+ // Check that the true and false destinations are the same and have the same
+ // operands.
+ Block *trueDest = condbr.trueDest();
+ if (trueDest != condbr.falseDest())
+ return failure();
+
+ // If all of the operands match, no selects need to be generated.
+ OperandRange trueOperands = condbr.getTrueOperands();
+ OperandRange falseOperands = condbr.getFalseOperands();
+ if (trueOperands == falseOperands) {
+ rewriter.replaceOpWithNewOp<BranchOp>(condbr, trueDest, trueOperands);
+ return success();
+ }
+
+ // Otherwise, if the current block is the only predecessor insert selects
+ // for any mismatched branch operands.
+ if (trueDest->getUniquePredecessor() != condbr.getOperation()->getBlock())
+ return failure();
+
+ // TODO: ATM Tensor/Vector SelectOp requires that the condition has the same
+ // shape as the operands. We should relax that to allow an i1 to signify
+ // that everything is selected.
+ auto doesntSupportsScalarI1 = [](Type type) {
+ return type.isa<TensorType>() || type.isa<VectorType>();
+ };
+ if (llvm::any_of(trueOperands.getTypes(), doesntSupportsScalarI1))
+ return failure();
+
+ // Generate a select for any operands that differ between the two.
+ SmallVector<Value, 8> mergedOperands;
+ mergedOperands.reserve(trueOperands.size());
+ Value condition = condbr.getCondition();
+ for (auto it : llvm::zip(trueOperands, falseOperands)) {
+ if (std::get<0>(it) == std::get<1>(it))
+ mergedOperands.push_back(std::get<0>(it));
+ else
+ mergedOperands.push_back(rewriter.create<SelectOp>(
+ condbr.getLoc(), condition, std::get<0>(it), std::get<1>(it)));
+ }
+
+ rewriter.replaceOpWithNewOp<BranchOp>(condbr, trueDest, mergedOperands);
+ return success();
+ }
+};
} // end anonymous namespace
void CondBranchOp::getCanonicalizationPatterns(
OwningRewritePatternList &results, MLIRContext *context) {
- results.insert<SimplifyConstCondBranchPred, SimplifyPassThroughCondBranch>(
- context);
+ results.insert<SimplifyConstCondBranchPred, SimplifyPassThroughCondBranch,
+ SimplifyCondBranchIdenticalSuccessors>(context);
}
Optional<OperandRange> CondBranchOp::getSuccessorOperands(unsigned index) {
diff --git a/mlir/lib/IR/Block.cpp b/mlir/lib/IR/Block.cpp
index 392af93a3530..80131325adff 100644
--- a/mlir/lib/IR/Block.cpp
+++ b/mlir/lib/IR/Block.cpp
@@ -229,6 +229,21 @@ Block *Block::getSinglePredecessor() {
return it == pred_end() ? firstPred : nullptr;
}
+/// If this block has a unique predecessor, i.e., all incoming edges originate
+/// from one block, return it. Otherwise, return null.
+Block *Block::getUniquePredecessor() {
+ auto it = pred_begin(), e = pred_end();
+ if (it == e)
+ return nullptr;
+
+ // Check for any conflicting predecessors.
+ auto *firstPred = *it;
+ for (++it; it != e; ++it)
+ if (*it != firstPred)
+ return nullptr;
+ return firstPred;
+}
+
//===----------------------------------------------------------------------===//
// Other
//===----------------------------------------------------------------------===//
diff --git a/mlir/test/Dialect/Standard/canonicalize-cf.mlir b/mlir/test/Dialect/Standard/canonicalize-cf.mlir
index 571c05505b48..8b7b3020fae0 100644
--- a/mlir/test/Dialect/Standard/canonicalize-cf.mlir
+++ b/mlir/test/Dialect/Standard/canonicalize-cf.mlir
@@ -1,6 +1,6 @@
// RUN: mlir-opt %s -allow-unregistered-dialect -pass-pipeline='func(canonicalize)' -split-input-file | FileCheck %s
-// Test the folding of BranchOp.
+/// Test the folding of BranchOp.
// CHECK-LABEL: func @br_folding(
func @br_folding() -> i32 {
@@ -12,11 +12,11 @@ func @br_folding() -> i32 {
return %x : i32
}
-// Test the folding of CondBranchOp with a constant condition.
+/// Test the folding of CondBranchOp with a constant condition.
// CHECK-LABEL: func @cond_br_folding(
func @cond_br_folding(%cond : i1, %a : i32) {
- // CHECK-NEXT: cond_br %{{.*}}, ^bb1, ^bb1
+ // CHECK-NEXT: return
%false_cond = constant 0 : i1
%true_cond = constant 1 : i1
@@ -29,13 +29,62 @@ func @cond_br_folding(%cond : i1, %a : i32) {
cond_br %false_cond, ^bb2(%x : i32), ^bb3
^bb3:
- // CHECK: ^bb1:
+ return
+}
+
+/// Test the folding of CondBranchOp when the successors are identical.
+
+// CHECK-LABEL: func @cond_br_same_successor(
+func @cond_br_same_successor(%cond : i1, %a : i32) {
// CHECK-NEXT: return
+ cond_br %cond, ^bb1(%a : i32), ^bb1(%a : i32)
+
+^bb1(%result : i32):
return
}
-// Test the compound folding of BranchOp and CondBranchOp.
+/// Test the folding of CondBranchOp when the successors are identical, but the
+/// arguments are different.
+
+// CHECK-LABEL: func @cond_br_same_successor_insert_select(
+// CHECK-SAME: %[[COND:.*]]: i1, %[[ARG0:.*]]: i32, %[[ARG1:.*]]: i32
+func @cond_br_same_successor_insert_select(%cond : i1, %a : i32, %b : i32) -> i32 {
+ // CHECK: %[[RES:.*]] = select %[[COND]], %[[ARG0]], %[[ARG1]]
+ // CHECK: return %[[RES]]
+
+ cond_br %cond, ^bb1(%a : i32), ^bb1(%b : i32)
+
+^bb1(%result : i32):
+ return %result : i32
+}
+
+/// Check that we don't generate a select if the type requires a splat.
+/// TODO: SelectOp should allow for matching a vector/tensor with i1.
+
+// CHECK-LABEL: func @cond_br_same_successor_no_select_tensor(
+func @cond_br_same_successor_no_select_tensor(%cond : i1, %a : tensor<2xi32>,
+ %b : tensor<2xi32>) -> tensor<2xi32>{
+ // CHECK: cond_br
+
+ cond_br %cond, ^bb1(%a : tensor<2xi32>), ^bb1(%b : tensor<2xi32>)
+
+^bb1(%result : tensor<2xi32>):
+ return %result : tensor<2xi32>
+}
+
+// CHECK-LABEL: func @cond_br_same_successor_no_select_vector(
+func @cond_br_same_successor_no_select_vector(%cond : i1, %a : vector<2xi32>,
+ %b : vector<2xi32>) -> vector<2xi32> {
+ // CHECK: cond_br
+
+ cond_br %cond, ^bb1(%a : vector<2xi32>), ^bb1(%b : vector<2xi32>)
+
+^bb1(%result : vector<2xi32>):
+ return %result : vector<2xi32>
+}
+
+/// Test the compound folding of BranchOp and CondBranchOp.
// CHECK-LABEL: func @cond_br_and_br_folding(
func @cond_br_and_br_folding(%a : i32) {
@@ -55,9 +104,11 @@ func @cond_br_and_br_folding(%a : i32) {
/// Test that pass-through successors of CondBranchOp get folded.
// CHECK-LABEL: func @cond_br_pass_through(
-// CHECK-SAME: %[[ARG0:.*]]: i32, %[[ARG1:.*]]: i32, %[[ARG2:.*]]: i32
+// CHECK-SAME: %[[ARG0:.*]]: i32, %[[ARG1:.*]]: i32, %[[ARG2:.*]]: i32, %[[COND:.*]]: i1
func @cond_br_pass_through(%arg0 : i32, %arg1 : i32, %arg2 : i32, %cond : i1) -> (i32, i32) {
- // CHECK: cond_br %{{.*}}, ^bb1(%[[ARG0]], %[[ARG1]] : i32, i32), ^bb1(%[[ARG2]], %[[ARG2]] : i32, i32)
+ // CHECK: %[[RES:.*]] = select %[[COND]], %[[ARG0]], %[[ARG2]]
+ // CHECK: %[[RES2:.*]] = select %[[COND]], %[[ARG1]], %[[ARG2]]
+ // CHECK: return %[[RES]], %[[RES2]]
cond_br %cond, ^bb1(%arg0 : i32), ^bb2(%arg2, %arg2 : i32, i32)
@@ -65,9 +116,6 @@ func @cond_br_pass_through(%arg0 : i32, %arg1 : i32, %arg2 : i32, %cond : i1) ->
br ^bb2(%arg3, %arg1 : i32, i32)
^bb2(%arg4: i32, %arg5: i32):
- // CHECK: ^bb1(%[[RET0:.*]]: i32, %[[RET1:.*]]: i32):
- // CHECK-NEXT: return %[[RET0]], %[[RET1]]
-
return %arg4, %arg5 : i32, i32
}
diff --git a/mlir/test/IR/invalid-ops.mlir b/mlir/test/IR/invalid-ops.mlir
index fe2556d0a9d9..c7b290517f02 100644
--- a/mlir/test/IR/invalid-ops.mlir
+++ b/mlir/test/IR/invalid-ops.mlir
@@ -297,12 +297,6 @@ func @func_with_ops(i1, tensor<42xi32>, tensor<?xi32>) {
// -----
-func @invalid_select_shape(%cond : i1, %idx : () -> ()) {
- // expected-error@+1 {{'result' must be signless-integer-like or floating-point-like, but got '() -> ()'}}
- %sel = select %cond, %idx, %idx : () -> ()
-
-// -----
-
func @invalid_cmp_shape(%idx : () -> ()) {
// expected-error@+1 {{'lhs' must be signless-integer-like, but got '() -> ()'}}
%cmp = cmpi "eq", %idx, %idx : () -> ()
More information about the Mlir-commits
mailing list