[Mlir-commits] [mlir] 8e211bf - [MLIR][SCF] Assume uses of condition in the body of scf.while is true

William S. Moses llvmlistbot at llvm.org
Tue May 4 08:40:43 PDT 2021


Author: William S. Moses
Date: 2021-05-04T11:39:07-04:00
New Revision: 8e211bf1c895a31b3e9f49014b5494d8e1dabcf6

URL: https://github.com/llvm/llvm-project/commit/8e211bf1c895a31b3e9f49014b5494d8e1dabcf6
DIFF: https://github.com/llvm/llvm-project/commit/8e211bf1c895a31b3e9f49014b5494d8e1dabcf6.diff

LOG: [MLIR][SCF] Assume uses of the condition in the body of scf.while are true

Differential Revision: https://reviews.llvm.org/D101801

Added: 
    

Modified: 
    mlir/include/mlir/Dialect/SCF/SCFOps.td
    mlir/lib/Dialect/SCF/SCF.cpp
    mlir/test/Dialect/SCF/canonicalize.mlir

Removed: 
    


################################################################################
diff  --git a/mlir/include/mlir/Dialect/SCF/SCFOps.td b/mlir/include/mlir/Dialect/SCF/SCFOps.td
index 28348f083f167..c3c64e04ae08c 100644
--- a/mlir/include/mlir/Dialect/SCF/SCFOps.td
+++ b/mlir/include/mlir/Dialect/SCF/SCFOps.td
@@ -586,7 +586,11 @@ def WhileOp : SCF_Op<"while",
 
   let extraClassDeclaration = [{
     OperandRange getSuccessorEntryOperands(unsigned index);
+    ConditionOp getConditionOp();
+    Block::BlockArgListType getAfterArguments();
   }];
+
+  let hasCanonicalizer = 1;
 }
 
 def YieldOp : SCF_Op<"yield", [NoSideEffect, ReturnLike, Terminator,

diff  --git a/mlir/lib/Dialect/SCF/SCF.cpp b/mlir/lib/Dialect/SCF/SCF.cpp
index 141b8802eff2e..c28e438fc8199 100644
--- a/mlir/lib/Dialect/SCF/SCF.cpp
+++ b/mlir/lib/Dialect/SCF/SCF.cpp
@@ -1694,6 +1694,14 @@ OperandRange WhileOp::getSuccessorEntryOperands(unsigned index) {
   return inits();
 }
 
+/// Returns the terminator of the first block of the "before" region, viewed as
+/// an `scf.condition` op (`cast` asserts on a mismatch when assertions are on).
+ConditionOp WhileOp::getConditionOp() {
+  Block &beforeBlock = before().front();
+  Operation *terminator = beforeBlock.getTerminator();
+  return cast<ConditionOp>(terminator);
+}
+
+/// Returns the arguments of the first block of the "after" ("do") region.
+/// These arguments are paired positionally with the operands of the
+/// `scf.condition` terminator.
+Block::BlockArgListType WhileOp::getAfterArguments() {
+  Block &afterBlock = after().front();
+  return afterBlock.getArguments();
+}
+
 void WhileOp::getSuccessorRegions(Optional<unsigned> index,
                                   ArrayRef<Attribute> operands,
                                   SmallVectorImpl<RegionSuccessor> &regions) {
@@ -1835,6 +1843,62 @@ static LogicalResult verify(scf::WhileOp op) {
   return success(afterTerminator != nullptr);
 }
 
+namespace {
+/// Replaces uses of any "after"-region ("do" block) argument that receives the
+/// loop condition itself with a constant `true`. The pattern matches only
+/// arguments whose forwarded `scf.condition` operand is the condition value.
+/// The "after" region runs only when that condition holds, so such an
+/// argument is known to be true there.
+///
+/// scf.while (..) : (i1, ...) -> ... {
+///  %condition = call @evaluate_condition() : () -> i1
+///  scf.condition(%condition) %condition : i1, ...
+/// } do {
+/// ^bb0(%arg0: i1, ...):
+///    use(%arg0)
+///    ...
+///
+/// becomes
+/// %true = constant true
+/// scf.while (..) : (i1, ...) -> ... {
+///  %condition = call @evaluate_condition() : () -> i1
+///  scf.condition(%condition) %condition : i1, ...
+/// } do {
+/// ^bb0(%arg0: i1, ...):
+///    use(%true)
+///    ...
+///
+/// Fails to match (no change) when no such argument has any uses.
+struct WhileConditionTruth : public OpRewritePattern<WhileOp> {
+  using OpRewritePattern<WhileOp>::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(WhileOp op,
+                                PatternRewriter &rewriter) const override {
+    auto term = op.getConditionOp();
+    Value condition = term.condition();
+
+    // Created lazily on the first replaced use and shared by all later ones,
+    // so at most one constant is materialized per rewrite.
+    Value constantTrue = nullptr;
+
+    bool replaced = false;
+    for (auto yieldedAndBlockArgs :
+         llvm::zip(term.args(), op.getAfterArguments())) {
+      if (std::get<0>(yieldedAndBlockArgs) != condition)
+        continue;
+
+      Value blockArg = std::get<1>(yieldedAndBlockArgs);
+      // `use.set` unlinks the use from the list being walked, so advance the
+      // iterator before the loop body runs.
+      for (OpOperand &use : llvm::make_early_inc_range(blockArg.getUses())) {
+        if (!constantTrue)
+          constantTrue = rewriter.create<mlir::ConstantOp>(
+              op.getLoc(), condition.getType(), rewriter.getBoolAttr(true));
+
+        // Pattern IR mutations must go through the PatternRewriter. Wrap the
+        // in-place operand update so the rewriter is told the user changed,
+        // instead of calling Value::replaceAllUsesWith behind its back.
+        rewriter.updateRootInPlace(use.getOwner(),
+                                   [&]() { use.set(constantTrue); });
+        replaced = true;
+      }
+    }
+    return success(replaced);
+  }
+};
+} // namespace
+
+/// Adds the scf.while canonicalization patterns to `results`. Currently this
+/// is only the pattern that folds forwarded-condition arguments to `true`.
+void WhileOp::getCanonicalizationPatterns(OwningRewritePatternList &results,
+                                          MLIRContext *context) {
+  results.insert<WhileConditionTruth>(context);
+}
+
 //===----------------------------------------------------------------------===//
 // TableGen'd op method definitions
 //===----------------------------------------------------------------------===//

diff  --git a/mlir/test/Dialect/SCF/canonicalize.mlir b/mlir/test/Dialect/SCF/canonicalize.mlir
index 4dee3825d8709..3ba8e8023155e 100644
--- a/mlir/test/Dialect/SCF/canonicalize.mlir
+++ b/mlir/test/Dialect/SCF/canonicalize.mlir
@@ -724,3 +724,26 @@ func @replace_if_with_cond3(%arg0 : i1, %arg2: i64) -> (i32, i64) {
 // CHECK-NEXT:       scf.yield %[[sv2]] : i32
 // CHECK-NEXT:     }
 // CHECK-NEXT:     return %[[if]], %arg1 : i32, i64
+
+
+// Canonicalization replaces uses of a "do"-region argument that receives the
+// forwarded loop condition with a constant true materialized before the loop.
+// CHECK-LABEL: @while_cond_true
+func @while_cond_true() {
+  %0 = scf.while () : () -> i1 {
+    %condition = "test.condition"() : () -> i1
+    scf.condition(%condition) %condition : i1
+  } do {
+  ^bb0(%arg0: i1):
+    "test.use"(%arg0) : (i1) -> ()
+    scf.yield
+  }
+  return
+}
+// CHECK-NEXT:         %[[true:.+]] = constant true
+// CHECK-NEXT:         %{{.+}} = scf.while : () -> i1 {
+// CHECK-NEXT:           %[[cmp:.+]] = "test.condition"() : () -> i1
+// CHECK-NEXT:           scf.condition(%[[cmp]]) %[[cmp]] : i1
+// CHECK-NEXT:         } do {
+// CHECK-NEXT:         ^bb0(%arg0: i1):  // no predecessors
+// CHECK-NEXT:           "test.use"(%[[true]]) : (i1) -> ()
+// CHECK-NEXT:           scf.yield
+// CHECK-NEXT:         }


        


More information about the Mlir-commits mailing list