[Mlir-commits] [mlir] d23fa4f - [MLIR][SCF] Remove unused arguments to whileop

William S. Moses llvmlistbot at llvm.org
Tue Jan 11 17:18:13 PST 2022


Author: William S. Moses
Date: 2022-01-11T20:18:08-05:00
New Revision: d23fa4f2f1319039b7d939a4bf68c493ba26b07a

URL: https://github.com/llvm/llvm-project/commit/d23fa4f2f1319039b7d939a4bf68c493ba26b07a
DIFF: https://github.com/llvm/llvm-project/commit/d23fa4f2f1319039b7d939a4bf68c493ba26b07a.diff

LOG: [MLIR][SCF] Remove unused arguments to whileop

Canonicalize away unused arguments of the "before" region of an scf.while op.

Reviewed By: mehdi_amini

Differential Revision: https://reviews.llvm.org/D117059

Added: 
    

Modified: 
    mlir/include/mlir/Dialect/SCF/SCFOps.td
    mlir/lib/Dialect/SCF/SCF.cpp
    mlir/test/Dialect/SCF/canonicalize.mlir

Removed: 
    


################################################################################
diff  --git a/mlir/include/mlir/Dialect/SCF/SCFOps.td b/mlir/include/mlir/Dialect/SCF/SCFOps.td
index c8e51692252d6..e060dac6ce211 100644
--- a/mlir/include/mlir/Dialect/SCF/SCFOps.td
+++ b/mlir/include/mlir/Dialect/SCF/SCFOps.td
@@ -686,6 +686,8 @@ def WhileOp : SCF_Op<"while",
   let extraClassDeclaration = [{
     OperandRange getSuccessorEntryOperands(unsigned index);
     ConditionOp getConditionOp();
+    YieldOp getYieldOp();
+    Block::BlockArgListType getBeforeArguments();
     Block::BlockArgListType getAfterArguments();
   }];
 

diff  --git a/mlir/lib/Dialect/SCF/SCF.cpp b/mlir/lib/Dialect/SCF/SCF.cpp
index 6b9fe80bbccdd..5fe73a8d57c51 100644
--- a/mlir/lib/Dialect/SCF/SCF.cpp
+++ b/mlir/lib/Dialect/SCF/SCF.cpp
@@ -2171,6 +2171,16 @@ ConditionOp WhileOp::getConditionOp() {
   return cast<ConditionOp>(getBefore().front().getTerminator());
 }
 
+/// Returns the scf.yield terminating the entry block of the "after" region.
+YieldOp WhileOp::getYieldOp() {
+  return cast<YieldOp>(getAfter().front().getTerminator());
+}
+
+/// Returns the arguments of the entry block of the "before" region.
+Block::BlockArgListType WhileOp::getBeforeArguments() {
+  return getBefore().front().getArguments();
+}
+
 Block::BlockArgListType WhileOp::getAfterArguments() {
   return getAfter().front().getArguments();
 }
@@ -2508,11 +2518,67 @@ struct WhileCmpCond : public OpRewritePattern<scf::WhileOp> {
     return success(changed);
   }
 };
+
+/// Removes arguments of the "before" region of an scf.while that have no
+/// uses, together with the matching init operands and scf.yield operands.
+/// Result types are unchanged: the results are produced by scf.condition.
+struct WhileUnusedArg : public OpRewritePattern<WhileOp> {
+  using OpRewritePattern<WhileOp>::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(WhileOp op,
+                                PatternRewriter &rewriter) const override {
+    if (llvm::none_of(op.getBeforeArguments(),
+                      [](Value arg) { return arg.use_empty(); }))
+      return rewriter.notifyMatchFailure(op, "no unused before-region args");
+
+    YieldOp yield = op.getYieldOp();
+
+    // Keep the init and yield operands of used arguments and record the
+    // indices of unused ones. The three ranges have equal length in a
+    // verified op; the early exit above guarantees argsToErase is non-empty.
+    SmallVector<Value> newYields;
+    SmallVector<Value> newInits;
+    SmallVector<unsigned> argsToErase;
+    for (const auto &it : llvm::enumerate(llvm::zip(
+             op.getBeforeArguments(), yield.getOperands(), op.getInits()))) {
+      Value beforeArg = std::get<0>(it.value());
+      Value yieldValue = std::get<1>(it.value());
+      Value initValue = std::get<2>(it.value());
+      if (beforeArg.use_empty()) {
+        argsToErase.push_back(it.index());
+      } else {
+        newYields.push_back(yieldValue);
+        newInits.push_back(initValue);
+      }
+    }
+
+    // The erased arguments have no uses. The owning op is replaced below, so
+    // it is not reported as updated in place.
+    op.getBefore().front().eraseArguments(argsToErase);
+
+    // Drop the yield operands that fed the erased arguments.
+    rewriter.setInsertionPoint(yield);
+    rewriter.replaceOpWithNewOp<YieldOp>(yield, newYields);
+
+    // Create the new loop right before the old one and move both regions
+    // into it through the rewriter rather than Region::takeBody.
+    rewriter.setInsertionPoint(op);
+    auto replacement =
+        rewriter.create<WhileOp>(op.getLoc(), op.getResultTypes(), newInits);
+    rewriter.inlineRegionBefore(op.getBefore(), replacement.getBefore(),
+                                replacement.getBefore().end());
+    rewriter.inlineRegionBefore(op.getAfter(), replacement.getAfter(),
+                                replacement.getAfter().end());
+    rewriter.replaceOp(op, replacement.getResults());
+    return success();
+  }
+};
 } // namespace
 
 void WhileOp::getCanonicalizationPatterns(OwningRewritePatternList &results,
                                           MLIRContext *context) {
-  results.insert<WhileConditionTruth, WhileUnusedResult, WhileCmpCond>(context);
+  results.insert<WhileConditionTruth, WhileUnusedResult, WhileCmpCond,
+                 WhileUnusedArg>(context);
 }
 
 //===----------------------------------------------------------------------===//

diff  --git a/mlir/test/Dialect/SCF/canonicalize.mlir b/mlir/test/Dialect/SCF/canonicalize.mlir
index 2b3a31e833a3d..65c20542fb2f6 100644
--- a/mlir/test/Dialect/SCF/canonicalize.mlir
+++ b/mlir/test/Dialect/SCF/canonicalize.mlir
@@ -846,6 +846,55 @@ func @while_cond_true() -> i1 {
 
 // -----
 
+// CHECK-LABEL: @while_unused_arg
+func @while_unused_arg(%x : i32, %y : f64) -> i32 {
+  %0 = scf.while (%arg1 = %x, %arg2 = %y) : (i32, f64) -> (i32) {
+    %condition = "test.condition"(%arg1) : (i32) -> i1
+    scf.condition(%condition) %arg1 : i32
+  } do {
+  ^bb0(%arg1: i32):
+    %next = "test.use"(%arg1) : (i32) -> (i32)
+    scf.yield %next, %y : i32, f64
+  }
+  return %0 : i32
+}
+// CHECK-NEXT:         %[[res:.*]] = scf.while (%[[arg2:.+]] = %{{.*}}) : (i32) -> i32 {
+// CHECK-NEXT:           %[[cmp:.*]] = "test.condition"(%[[arg2]]) : (i32) -> i1
+// CHECK-NEXT:           scf.condition(%[[cmp]]) %[[arg2]] : i32
+// CHECK-NEXT:         } do {
+// CHECK-NEXT:         ^bb0(%[[post:.+]]: i32):  // no predecessors
+// CHECK-NEXT:           %[[next:.+]] = "test.use"(%[[post]]) : (i32) -> i32
+// CHECK-NEXT:           scf.yield %[[next]] : i32
+// CHECK-NEXT:         }
+// CHECK-NEXT:         return %[[res]] : i32
+
+// -----
+
+// CHECK-LABEL: @while_unused_leading_arg
+// CHECK-SAME:    (%{{.*}}: f64, %[[init:.*]]: i32)
+func @while_unused_leading_arg(%x : f64, %y : i32) -> i32 {
+  %0 = scf.while (%arg1 = %x, %arg2 = %y) : (f64, i32) -> (i32) {
+    %condition = "test.condition"(%arg2) : (i32) -> i1
+    scf.condition(%condition) %arg2 : i32
+  } do {
+  ^bb0(%arg2: i32):
+    %next = "test.use"(%arg2) : (i32) -> (i32)
+    scf.yield %x, %next : f64, i32
+  }
+  return %0 : i32
+}
+// CHECK-NEXT:         %[[res:.*]] = scf.while (%[[arg:.+]] = %[[init]]) : (i32) -> i32 {
+// CHECK-NEXT:           %[[cmp:.*]] = "test.condition"(%[[arg]]) : (i32) -> i1
+// CHECK-NEXT:           scf.condition(%[[cmp]]) %[[arg]] : i32
+// CHECK-NEXT:         } do {
+// CHECK-NEXT:         ^bb0(%[[post:.+]]: i32):  // no predecessors
+// CHECK-NEXT:           %[[next:.+]] = "test.use"(%[[post]]) : (i32) -> i32
+// CHECK-NEXT:           scf.yield %[[next]] : i32
+// CHECK-NEXT:         }
+// CHECK-NEXT:         return %[[res]] : i32
+
+// -----
+
 // CHECK-LABEL: @while_unused_result
 func @while_unused_result() -> i32 {
   %0:2 = scf.while () : () -> (i32, i64) {


        


More information about the Mlir-commits mailing list