[Mlir-commits] [mlir] 7f5d9bf - [mlir][scf] Canonicalize scf.while with unused results

llvmlistbot at llvm.org llvmlistbot at llvm.org
Wed Nov 24 00:40:17 PST 2021


Author: Butygin
Date: 2021-11-24T11:11:22+03:00
New Revision: 7f5d9bf13a7d872dc173d09598743c006f59d6b8

URL: https://github.com/llvm/llvm-project/commit/7f5d9bf13a7d872dc173d09598743c006f59d6b8
DIFF: https://github.com/llvm/llvm-project/commit/7f5d9bf13a7d872dc173d09598743c006f59d6b8.diff

LOG: [mlir][scf] Canonicalize scf.while with unused results

Differential Revision: https://reviews.llvm.org/D114291

Added: 
    

Modified: 
    mlir/lib/Dialect/SCF/SCF.cpp
    mlir/test/Dialect/SCF/canonicalize.mlir

Removed: 
    


################################################################################
diff  --git a/mlir/lib/Dialect/SCF/SCF.cpp b/mlir/lib/Dialect/SCF/SCF.cpp
index 57ad2f412852f..358d28bcc8367 100644
--- a/mlir/lib/Dialect/SCF/SCF.cpp
+++ b/mlir/lib/Dialect/SCF/SCF.cpp
@@ -2255,11 +2255,102 @@ struct WhileConditionTruth : public OpRewritePattern<WhileOp> {
     return success(replaced);
   }
 };
+
+/// Remove scf.while results that have no uses when the matching 'after' block
+/// argument has no uses either; the matching scf.condition operand is dropped
+/// too. The pattern fails to match when no result qualifies. Example:
+///  %0:2 = scf.while () : () -> (i32, i64) {
+///    %condition = "test.condition"() : () -> i1
+///    %v1 = "test.get_some_value"() : () -> i32
+///    %v2 = "test.get_some_value"() : () -> i64
+///    scf.condition(%condition) %v1, %v2 : i32, i64
+///  } do {
+///  ^bb0(%arg0: i32, %arg1: i64):
+///    "test.use"(%arg0) : (i32) -> ()
+///    scf.yield
+///  }
+///  return %0#0 : i32
+/// becomes (%v2 is no longer forwarded; this pattern does not erase its op):
+///  %0 = scf.while () : () -> (i32) {
+///    %condition = "test.condition"() : () -> i1
+///    %v1 = "test.get_some_value"() : () -> i32
+///    %v2 = "test.get_some_value"() : () -> i64
+///    scf.condition(%condition) %v1 : i32
+///  } do {
+///  ^bb0(%arg0: i32):
+///    "test.use"(%arg0) : (i32) -> ()
+///    scf.yield
+///  }
+///  return %0 : i32
+struct WhileUnusedResult : public OpRewritePattern<WhileOp> {
+  using OpRewritePattern<WhileOp>::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(WhileOp op,
+                                PatternRewriter &rewriter) const override {
+    auto term = op.getConditionOp(); // scf.condition ending 'before'.
+    auto afterArgs = op.getAfterArguments();
+    auto termArgs = term.args(); // Values forwarded by scf.condition.
+
+    // Record kept indices plus the condition operands/types they carry over.
+    SmallVector<unsigned> newResultsIndices; // Old index of each kept entry.
+    SmallVector<Type> newResultTypes;
+    SmallVector<Value> newTermArgs;
+    bool needUpdate = false; // Set once any entry is dropped.
+    for (auto it : // Entry i: result i, 'after' arg i, condition operand i.
+         llvm::enumerate(llvm::zip(op.getResults(), afterArgs, termArgs))) {
+      auto i = static_cast<unsigned>(it.index());
+      Value result = std::get<0>(it.value());
+      Value afterArg = std::get<1>(it.value());
+      Value termArg = std::get<2>(it.value());
+      if (result.use_empty() && afterArg.use_empty()) { // Drop entry i.
+        needUpdate = true;
+      } else { // Keep entry i and remember its old position.
+        newResultsIndices.emplace_back(i);
+        newTermArgs.emplace_back(termArg);
+        newResultTypes.emplace_back(result.getType());
+      }
+    }
+
+    if (!needUpdate) // Nothing to drop: leave the op untouched.
+      return failure();
+
+    { // Shrink scf.condition first; 'before' is moved as-is below.
+      OpBuilder::InsertionGuard g(rewriter); // Restores the insertion point.
+      rewriter.setInsertionPoint(term);
+      rewriter.replaceOpWithNewOp<ConditionOp>(term, term.condition(),
+                                               newTermArgs);
+    }
+
+    auto newWhile = // Same inits, only the kept result types.
+        rewriter.create<WhileOp>(op.getLoc(), newResultTypes, op.inits());
+
+    Block &newAfterBlock = *rewriter.createBlock( // One arg per kept result.
+        &newWhile.after(), /*insertPt*/ {}, newResultTypes);
+
+    // Map old positions to new values; dropped positions stay null, which is
+    // safe because those results/arguments were verified unused above.
+    SmallVector<Value> newResults(op.getNumResults());
+    SmallVector<Value> newAfterBlockArgs(op.getNumResults());
+    for (auto it : llvm::enumerate(newResultsIndices)) { // new pos -> old pos
+      newResults[it.value()] = newWhile.getResult(it.index());
+      newAfterBlockArgs[it.value()] = newAfterBlock.getArgument(it.index());
+    }
+
+    rewriter.inlineRegionBefore(op.before(), newWhile.before(),
+                                newWhile.before().begin()); // Moved whole.
+
+    Block &afterBlock = op.after().front(); // Ops move into the new block.
+    rewriter.mergeBlocks(&afterBlock, &newAfterBlock, newAfterBlockArgs);
+
+    rewriter.replaceOp(op, newResults); // Null entries = unused results.
+    return success();
+  }
+};
 } // namespace
 
 void WhileOp::getCanonicalizationPatterns(OwningRewritePatternList &results,
                                           MLIRContext *context) {
-  results.insert<WhileConditionTruth>(context);
+  results.insert<WhileConditionTruth, WhileUnusedResult>(context); // Both.
 }
 
 //===----------------------------------------------------------------------===//

diff  --git a/mlir/test/Dialect/SCF/canonicalize.mlir b/mlir/test/Dialect/SCF/canonicalize.mlir
index c8471b8adbc60..6f1815713819c 100644
--- a/mlir/test/Dialect/SCF/canonicalize.mlir
+++ b/mlir/test/Dialect/SCF/canonicalize.mlir
@@ -782,7 +782,7 @@ func @replace_if_with_cond3(%arg0 : i1, %arg2: i64) -> (i32, i64) {
 // -----
 
 // CHECK-LABEL: @while_cond_true
-func @while_cond_true() {
+func @while_cond_true() -> i1 {
   %0 = scf.while () : () -> i1 {
     %condition = "test.condition"() : () -> i1
     scf.condition(%condition) %condition : i1
@@ -791,7 +791,7 @@ func @while_cond_true() {
     "test.use"(%arg0) : (i1) -> ()
     scf.yield
   }
-  return
+  return %0 : i1
 }
 // CHECK-NEXT:         %[[true:.+]] = arith.constant true
 // CHECK-NEXT:         %{{.+}} = scf.while : () -> i1 {
@@ -805,6 +805,34 @@ func @while_cond_true() {
 
 // -----
 
+// CHECK-LABEL: @while_unused_result
+func @while_unused_result() -> i32 {
+  %0:2 = scf.while () : () -> (i32, i64) { // %0#1 is never used.
+    %condition = "test.condition"() : () -> i1
+    %v1 = "test.get_some_value"() : () -> i32
+    %v2 = "test.get_some_value"() : () -> i64
+    scf.condition(%condition) %v1, %v2 : i32, i64
+  } do {
+  ^bb0(%arg0: i32, %arg1: i64): // %arg1 is never used either.
+    "test.use"(%arg0) : (i32) -> ()
+    scf.yield
+  }
+  return %0#0 : i32
+}
+// CHECK-NEXT:         %[[res:.*]] = scf.while : () -> i32 {
+// CHECK-NEXT:           %[[cmp:.*]] = "test.condition"() : () -> i1
+// CHECK-NEXT:           %[[val:.*]] = "test.get_some_value"() : () -> i32
+// CHECK-NEXT:           %{{.*}} = "test.get_some_value"() : () -> i64
+// CHECK-NEXT:           scf.condition(%[[cmp]]) %[[val]] : i32
+// CHECK-NEXT:         } do {
+// CHECK-NEXT:         ^bb0(%[[arg:.*]]: i32):  // no predecessors
+// CHECK-NEXT:           "test.use"(%[[arg]]) : (i32) -> ()
+// CHECK-NEXT:           scf.yield
+// CHECK-NEXT:         }
+// CHECK-NEXT:         return %[[res]] : i32
+
+// -----
+
 // CHECK-LABEL: @combineIfs
 func @combineIfs(%arg0 : i1, %arg2: i64) -> (i32, i32) {
   %res = scf.if %arg0 -> i32 {


        


More information about the Mlir-commits mailing list