[Mlir-commits] [mlir] 97567bd - [MLIR][SCF] Canonicalize while statement whose cmp condition is recomputed in the after region
William S. Moses
llvmlistbot at llvm.org
Tue Jan 11 15:34:07 PST 2022
Author: William S. Moses
Date: 2022-01-11T18:34:04-05:00
New Revision: 97567bde5baaf87c0ac71cfc114dad56442b85d3
URL: https://github.com/llvm/llvm-project/commit/97567bde5baaf87c0ac71cfc114dad56442b85d3
DIFF: https://github.com/llvm/llvm-project/commit/97567bde5baaf87c0ac71cfc114dad56442b85d3.diff
LOG: [MLIR][SCF] Canonicalize while statement whose cmp condition is recomputed in the after region
Given a while loop whose condition is given by a cmp, don't recompute the comparison (or its inverse) in the after region; instead, use a constant, since the original condition must be true if we branched to the after region.
Reviewed By: mehdi_amini
Differential Revision: https://reviews.llvm.org/D117047
Added:
Modified:
mlir/include/mlir/Dialect/Arithmetic/IR/Arithmetic.h
mlir/lib/Dialect/Arithmetic/IR/ArithmeticOps.cpp
mlir/lib/Dialect/SCF/SCF.cpp
mlir/test/Dialect/SCF/canonicalize.mlir
Removed:
################################################################################
diff --git a/mlir/include/mlir/Dialect/Arithmetic/IR/Arithmetic.h b/mlir/include/mlir/Dialect/Arithmetic/IR/Arithmetic.h
index 31d6239388454..65d819f18df28 100644
--- a/mlir/include/mlir/Dialect/Arithmetic/IR/Arithmetic.h
+++ b/mlir/include/mlir/Dialect/Arithmetic/IR/Arithmetic.h
@@ -121,6 +121,8 @@ Value getIdentityValue(AtomicRMWKind op, Type resultType, OpBuilder &builder,
/// associated with a binary AtomicRMWKind op to `lhs` and `rhs`.
Value getReductionOp(AtomicRMWKind op, OpBuilder &builder, Location loc,
Value lhs, Value rhs);
+
+arith::CmpIPredicate invertPredicate(arith::CmpIPredicate pred);
} // namespace arith
} // namespace mlir
diff --git a/mlir/lib/Dialect/Arithmetic/IR/ArithmeticOps.cpp b/mlir/lib/Dialect/Arithmetic/IR/ArithmeticOps.cpp
index 69acc19fb9e4f..188d6e56543a4 100644
--- a/mlir/lib/Dialect/Arithmetic/IR/ArithmeticOps.cpp
+++ b/mlir/lib/Dialect/Arithmetic/IR/ArithmeticOps.cpp
@@ -40,7 +40,7 @@ static IntegerAttr subIntegerAttrs(PatternRewriter &builder, Value res,
}
/// Invert an integer comparison predicate.
-static arith::CmpIPredicate invertPredicate(arith::CmpIPredicate pred) {
+arith::CmpIPredicate arith::invertPredicate(arith::CmpIPredicate pred) {
switch (pred) {
case arith::CmpIPredicate::eq:
return arith::CmpIPredicate::ne;
diff --git a/mlir/lib/Dialect/SCF/SCF.cpp b/mlir/lib/Dialect/SCF/SCF.cpp
index 3d6d2052f7fe1..6b9fe80bbccdd 100644
--- a/mlir/lib/Dialect/SCF/SCF.cpp
+++ b/mlir/lib/Dialect/SCF/SCF.cpp
@@ -2443,11 +2443,76 @@ struct WhileUnusedResult : public OpRewritePattern<WhileOp> {
return success();
}
};
+
+/// Replace a cmpi in the do block that recomputes the loop condition with true
+/// (or false for the inverse predicate); the do block only runs if it held.
+///
+/// scf.while (..) : (i32, ...) -> ... {
+/// %z = ... : i32
+/// %condition = cmpi pred %z, %a
+/// scf.condition(%condition) %z : i32, ...
+/// } do {
+/// ^bb0(%arg0: i32, ...):
+/// %condition2 = cmpi pred %arg0, %a
+/// use(%condition2)
+/// ...
+///
+/// becomes
+/// scf.while (..) : (i32, ...) -> ... {
+/// %z = ... : i32
+/// %condition = cmpi pred %z, %a
+/// scf.condition(%condition) %z : i32, ...
+/// } do {
+/// ^bb0(%arg0: i32, ...):
+/// use(%true)
+/// ...
+struct WhileCmpCond : public OpRewritePattern<scf::WhileOp> {
+ using OpRewritePattern<scf::WhileOp>::OpRewritePattern;
+
+ LogicalResult matchAndRewrite(scf::WhileOp op,
+ PatternRewriter &rewriter) const override {
+ using namespace scf;
+ auto cond = op.getConditionOp();
+ auto cmp = cond.getCondition().getDefiningOp<arith::CmpIOp>();
+ if (!cmp)
+ return failure();
+ bool changed = false;
+ for (auto tup :
+ llvm::zip(cond.getArgs(), op.getAfter().front().getArguments())) {
+ for (size_t opIdx = 0; opIdx < 2; opIdx++) {
+ if (std::get<0>(tup) != cmp.getOperand(opIdx))
+ continue;
+ for (OpOperand &u :
+ llvm::make_early_inc_range(std::get<1>(tup).getUses())) {
+ auto cmp2 = dyn_cast<arith::CmpIOp>(u.getOwner());
+ if (!cmp2)
+ continue;
+ // For a binary operator 1-opIdx gets the other side.
+ if (cmp2.getOperand(1 - opIdx) != cmp.getOperand(1 - opIdx))
+ continue;
+ bool samePredicate;
+ if (cmp2.getPredicate() == cmp.getPredicate())
+ samePredicate = true;
+ else if (cmp2.getPredicate() ==
+ arith::invertPredicate(cmp.getPredicate()))
+ samePredicate = false;
+ else
+ continue;
+
+ rewriter.replaceOpWithNewOp<arith::ConstantIntOp>(cmp2, samePredicate,
+ 1);
+ changed = true;
+ }
+ }
+ }
+ return success(changed);
+ }
+};
} // namespace
void WhileOp::getCanonicalizationPatterns(OwningRewritePatternList &results,
MLIRContext *context) {
- results.insert<WhileConditionTruth, WhileUnusedResult>(context);
+ results.insert<WhileConditionTruth, WhileUnusedResult, WhileCmpCond>(context);
}
//===----------------------------------------------------------------------===//
diff --git a/mlir/test/Dialect/SCF/canonicalize.mlir b/mlir/test/Dialect/SCF/canonicalize.mlir
index 7f424d892b764..2b3a31e833a3d 100644
--- a/mlir/test/Dialect/SCF/canonicalize.mlir
+++ b/mlir/test/Dialect/SCF/canonicalize.mlir
@@ -872,6 +872,60 @@ func @while_unused_result() -> i32 {
// CHECK-NEXT: }
// CHECK-NEXT: return %[[res]] : i32
+// CHECK-LABEL: @while_cmp_lhs
+func @while_cmp_lhs(%arg0 : i32) {
+ %0 = scf.while () : () -> i32 {
+ %val = "test.val"() : () -> i32
+ %condition = arith.cmpi ne, %val, %arg0 : i32
+ scf.condition(%condition) %val : i32
+ } do {
+ ^bb0(%val2: i32):
+ %condition2 = arith.cmpi ne, %val2, %arg0 : i32
+ %negcondition2 = arith.cmpi eq, %val2, %arg0 : i32
+ "test.use"(%condition2, %negcondition2, %val2) : (i1, i1, i32) -> ()
+ scf.yield
+ }
+ return
+}
+// CHECK-DAG: %[[true:.+]] = arith.constant true
+// CHECK-DAG: %[[false:.+]] = arith.constant false
+// CHECK-DAG: %{{.+}} = scf.while : () -> i32 {
+// CHECK-NEXT: %[[val:.+]] = "test.val"
+// CHECK-NEXT: %[[cmp:.+]] = arith.cmpi ne, %[[val]], %arg0 : i32
+// CHECK-NEXT: scf.condition(%[[cmp]]) %[[val]] : i32
+// CHECK-NEXT: } do {
+// CHECK-NEXT: ^bb0(%arg1: i32): // no predecessors
+// CHECK-NEXT: "test.use"(%[[true]], %[[false]], %arg1) : (i1, i1, i32) -> ()
+// CHECK-NEXT: scf.yield
+// CHECK-NEXT: }
+
+// CHECK-LABEL: @while_cmp_rhs
+func @while_cmp_rhs(%arg0 : i32) {
+ %0 = scf.while () : () -> i32 {
+ %val = "test.val"() : () -> i32
+ %condition = arith.cmpi ne, %arg0, %val : i32
+ scf.condition(%condition) %val : i32
+ } do {
+ ^bb0(%val2: i32):
+ %condition2 = arith.cmpi ne, %arg0, %val2 : i32
+ %negcondition2 = arith.cmpi eq, %arg0, %val2 : i32
+ "test.use"(%condition2, %negcondition2, %val2) : (i1, i1, i32) -> ()
+ scf.yield
+ }
+ return
+}
+// CHECK-DAG: %[[true:.+]] = arith.constant true
+// CHECK-DAG: %[[false:.+]] = arith.constant false
+// CHECK-DAG: %{{.+}} = scf.while : () -> i32 {
+// CHECK-NEXT: %[[val:.+]] = "test.val"
+// CHECK-NEXT: %[[cmp:.+]] = arith.cmpi ne, %arg0, %[[val]] : i32
+// CHECK-NEXT: scf.condition(%[[cmp]]) %[[val]] : i32
+// CHECK-NEXT: } do {
+// CHECK-NEXT: ^bb0(%arg1: i32): // no predecessors
+// CHECK-NEXT: "test.use"(%[[true]], %[[false]], %arg1) : (i1, i1, i32) -> ()
+// CHECK-NEXT: scf.yield
+// CHECK-NEXT: }
+
// -----
// CHECK-LABEL: @combineIfs
More information about the Mlir-commits
mailing list