[Mlir-commits] [mlir] 8e211bf - [MLIR][SCF] Assume uses of condition in the body of scf.while is true
William S. Moses
llvmlistbot at llvm.org
Tue May 4 08:40:43 PDT 2021
Author: William S. Moses
Date: 2021-05-04T11:39:07-04:00
New Revision: 8e211bf1c895a31b3e9f49014b5494d8e1dabcf6
URL: https://github.com/llvm/llvm-project/commit/8e211bf1c895a31b3e9f49014b5494d8e1dabcf6
DIFF: https://github.com/llvm/llvm-project/commit/8e211bf1c895a31b3e9f49014b5494d8e1dabcf6.diff
LOG: [MLIR][SCF] Assume uses of condition in the body of scf.while is true
Differential Revision: https://reviews.llvm.org/D101801
Added:
Modified:
mlir/include/mlir/Dialect/SCF/SCFOps.td
mlir/lib/Dialect/SCF/SCF.cpp
mlir/test/Dialect/SCF/canonicalize.mlir
Removed:
################################################################################
diff --git a/mlir/include/mlir/Dialect/SCF/SCFOps.td b/mlir/include/mlir/Dialect/SCF/SCFOps.td
index 28348f083f167..c3c64e04ae08c 100644
--- a/mlir/include/mlir/Dialect/SCF/SCFOps.td
+++ b/mlir/include/mlir/Dialect/SCF/SCFOps.td
@@ -586,7 +586,11 @@ def WhileOp : SCF_Op<"while",
let extraClassDeclaration = [{
OperandRange getSuccessorEntryOperands(unsigned index);
+ ConditionOp getConditionOp();
+ Block::BlockArgListType getAfterArguments();
}];
+
+ let hasCanonicalizer = 1;
}
def YieldOp : SCF_Op<"yield", [NoSideEffect, ReturnLike, Terminator,
diff --git a/mlir/lib/Dialect/SCF/SCF.cpp b/mlir/lib/Dialect/SCF/SCF.cpp
index 141b8802eff2e..c28e438fc8199 100644
--- a/mlir/lib/Dialect/SCF/SCF.cpp
+++ b/mlir/lib/Dialect/SCF/SCF.cpp
@@ -1694,6 +1694,14 @@ OperandRange WhileOp::getSuccessorEntryOperands(unsigned index) {
return inits();
}
+/// Returns the terminator of the first block of the `before` region, cast to
+/// `ConditionOp`.
+ConditionOp WhileOp::getConditionOp() {
+  Operation *beforeTerminator = before().front().getTerminator();
+  return cast<ConditionOp>(beforeTerminator);
+}
+
+/// Returns the block arguments of the first block of the `after` region.
+Block::BlockArgListType WhileOp::getAfterArguments() {
+  Block &afterBlock = after().front();
+  return afterBlock.getArguments();
+}
+
void WhileOp::getSuccessorRegions(Optional<unsigned> index,
ArrayRef<Attribute> operands,
SmallVectorImpl<RegionSuccessor> &regions) {
@@ -1835,6 +1843,62 @@ static LogicalResult verify(scf::WhileOp op) {
return success(afterTerminator != nullptr);
}
+namespace {
+/// Replace uses of the condition within the do block with true, since otherwise
+/// the block would not be evaluated.
+///
+/// scf.while (..) : (i1, ...) -> ... {
+///  %condition = call @evaluate_condition() : () -> i1
+///  scf.condition(%condition) %condition : i1, ...
+/// } do {
+/// ^bb0(%arg0: i1, ...):
+///    use(%arg0)
+///    ...
+///
+/// becomes
+/// scf.while (..) : (i1, ...) -> ... {
+///  %condition = call @evaluate_condition() : () -> i1
+///  scf.condition(%condition) %condition : i1, ...
+/// } do {
+/// ^bb0(%arg0: i1, ...):
+///    use(%true)
+///    ...
+struct WhileConditionTruth : public OpRewritePattern<WhileOp> {
+  using OpRewritePattern<WhileOp>::OpRewritePattern;
+
+  /// Matches when at least one value forwarded by the `scf.condition`
+  /// terminator is the condition itself and the corresponding argument of the
+  /// `after` block still has uses. Returns success only if some use was
+  /// rewritten, and failure otherwise.
+  LogicalResult matchAndRewrite(WhileOp op,
+                                PatternRewriter &rewriter) const override {
+    auto term = op.getConditionOp();
+
+    // Lazily created `true` constant, shared by every replaced argument so
+    // that at most one constant is materialized per successful match.
+    Value constantTrue = nullptr;
+
+    bool replaced = false;
+    for (auto yieldedAndBlockArgs :
+         llvm::zip(term.args(), op.getAfterArguments())) {
+      Value yielded = std::get<0>(yieldedAndBlockArgs);
+      BlockArgument afterArg = std::get<1>(yieldedAndBlockArgs);
+
+      // Only arguments that forward the condition itself, and that are
+      // actually used, are of interest.
+      if (yielded != term.condition() || afterArg.use_empty())
+        continue;
+
+      if (!constantTrue)
+        constantTrue = rewriter.create<mlir::ConstantOp>(
+            op.getLoc(), term.condition().getType(),
+            rewriter.getBoolAttr(true));
+
+      // Route every operand update through the rewriter instead of calling
+      // replaceAllUsesWith on the value directly, so the pattern driver is
+      // notified about each user operation that is modified in place.
+      // The early-inc range is required because `set` unlinks the operand
+      // from the use list being iterated.
+      for (OpOperand &use : llvm::make_early_inc_range(afterArg.getUses()))
+        rewriter.updateRootInPlace(use.getOwner(),
+                                   [&] { use.set(constantTrue); });
+      replaced = true;
+    }
+    return success(replaced);
+  }
+};
+} // namespace
+
+/// Adds the canonicalization patterns of this op to `results`; currently the
+/// only one inserted is `WhileConditionTruth`.
+void WhileOp::getCanonicalizationPatterns(OwningRewritePatternList &results,
+                                          MLIRContext *context) {
+  results.insert<WhileConditionTruth>(context);
+}
+
//===----------------------------------------------------------------------===//
// TableGen'd op method definitions
//===----------------------------------------------------------------------===//
diff --git a/mlir/test/Dialect/SCF/canonicalize.mlir b/mlir/test/Dialect/SCF/canonicalize.mlir
index 4dee3825d8709..3ba8e8023155e 100644
--- a/mlir/test/Dialect/SCF/canonicalize.mlir
+++ b/mlir/test/Dialect/SCF/canonicalize.mlir
@@ -724,3 +724,26 @@ func @replace_if_with_cond3(%arg0 : i1, %arg2: i64) -> (i32, i64) {
// CHECK-NEXT: scf.yield %[[sv2]] : i32
// CHECK-NEXT: }
// CHECK-NEXT: return %[[if]], %arg1 : i32, i64
+
+
+// CHECK-LABEL: @while_cond_true
+// The loop condition is also forwarded to the "do" region as %arg0, so inside
+// that region its use is expected to be replaced by a constant true.
+func @while_cond_true() {
+ %0 = scf.while () : () -> i1 {
+ %condition = "test.condition"() : () -> i1
+ scf.condition(%condition) %condition : i1
+ } do {
+ ^bb0(%arg0: i1):
+ "test.use"(%arg0) : (i1) -> ()
+ scf.yield
+ }
+ return
+}
+// CHECK-NEXT: %[[true:.+]] = constant true
+// CHECK-NEXT: %{{.+}} = scf.while : () -> i1 {
+// CHECK-NEXT: %[[cmp:.+]] = "test.condition"() : () -> i1
+// CHECK-NEXT: scf.condition(%[[cmp]]) %[[cmp]] : i1
+// CHECK-NEXT: } do {
+// CHECK-NEXT: ^bb0(%arg0: i1): // no predecessors
+// CHECK-NEXT: "test.use"(%[[true]]) : (i1) -> ()
+// CHECK-NEXT: scf.yield
+// CHECK-NEXT: }
More information about the Mlir-commits
mailing list