[Mlir-commits] [mlir] c97b7bc - [mlir][scf] WhileOp patterns cleanup

Ivan Butygin llvmlistbot at llvm.org
Fri Apr 14 05:19:51 PDT 2023


Author: Ivan Butygin
Date: 2023-04-14T13:46:11+02:00
New Revision: c97b7bcf3e20b8423692e38dd0a1f249e6215bd8

URL: https://github.com/llvm/llvm-project/commit/c97b7bcf3e20b8423692e38dd0a1f249e6215bd8
DIFF: https://github.com/llvm/llvm-project/commit/c97b7bcf3e20b8423692e38dd0a1f249e6215bd8.diff

LOG: [mlir][scf] WhileOp patterns cleanup

Address review comments from https://reviews.llvm.org/D146252:
merge the `WhileRemoveUnusedArgs` pattern with the previously unused `WhileUnusedArg`
pattern, use `getConditionOp`, use a `SmallPtrSet` for an early duplicate check, and move the tests.

Differential Revision: https://reviews.llvm.org/D148256

Added: 
    

Modified: 
    mlir/lib/Dialect/SCF/IR/SCF.cpp
    mlir/test/Dialect/SCF/canonicalize.mlir

Removed: 
    


################################################################################
diff  --git a/mlir/lib/Dialect/SCF/IR/SCF.cpp b/mlir/lib/Dialect/SCF/IR/SCF.cpp
index ed6d9f25e296c..06d4addae84f0 100644
--- a/mlir/lib/Dialect/SCF/IR/SCF.cpp
+++ b/mlir/lib/Dialect/SCF/IR/SCF.cpp
@@ -22,6 +22,7 @@
 #include "mlir/Support/MathExtras.h"
 #include "mlir/Transforms/InliningUtils.h"
 #include "llvm/ADT/MapVector.h"
+#include "llvm/ADT/SmallPtrSet.h"
 #include "llvm/ADT/TypeSwitch.h"
 
 using namespace mlir;
@@ -3738,7 +3739,8 @@ struct WhileCmpCond : public OpRewritePattern<scf::WhileOp> {
   }
 };
 
-struct WhileUnusedArg : public OpRewritePattern<WhileOp> {
+/// Remove unused init/yield args.
+struct WhileRemoveUnusedArgs : public OpRewritePattern<WhileOp> {
   using OpRewritePattern<WhileOp>::OpRewritePattern;
 
   LogicalResult matchAndRewrite(WhileOp op,
@@ -3746,42 +3748,52 @@ struct WhileUnusedArg : public OpRewritePattern<WhileOp> {
 
     if (!llvm::any_of(op.getBeforeArguments(),
                       [](Value arg) { return arg.use_empty(); }))
-      return failure();
+      return rewriter.notifyMatchFailure(op, "No args to remove");
 
     YieldOp yield = op.getYieldOp();
 
     // Collect results mapping, new terminator args and new result types.
     SmallVector<Value> newYields;
     SmallVector<Value> newInits;
-    llvm::BitVector argsToErase(op.getBeforeArguments().size());
-    for (const auto &it : llvm::enumerate(llvm::zip(
-             op.getBeforeArguments(), yield.getOperands(), op.getInits()))) {
-      Value beforeArg = std::get<0>(it.value());
-      Value yieldValue = std::get<1>(it.value());
-      Value initValue = std::get<2>(it.value());
+    llvm::BitVector argsToErase;
+
+    size_t argsCount = op.getBeforeArguments().size();
+    newYields.reserve(argsCount);
+    newInits.reserve(argsCount);
+    argsToErase.reserve(argsCount);
+    for (auto &&[beforeArg, yieldValue, initValue] : llvm::zip(
+             op.getBeforeArguments(), yield.getOperands(), op.getInits())) {
       if (beforeArg.use_empty()) {
-        argsToErase.set(it.index());
+        argsToErase.push_back(true);
       } else {
+        argsToErase.push_back(false);
         newYields.emplace_back(yieldValue);
         newInits.emplace_back(initValue);
       }
     }
 
-    if (argsToErase.none())
-      return failure();
+    Block &beforeBlock = op.getBefore().front();
+    Block &afterBlock = op.getAfter().front();
 
-    rewriter.startRootUpdate(op);
-    op.getBefore().front().eraseArguments(argsToErase);
-    rewriter.finalizeRootUpdate(op);
+    beforeBlock.eraseArguments(argsToErase);
 
-    WhileOp replacement =
-        rewriter.create<WhileOp>(op.getLoc(), op.getResultTypes(), newInits);
-    replacement.getBefore().takeBody(op.getBefore());
-    replacement.getAfter().takeBody(op.getAfter());
-    rewriter.replaceOp(op, replacement.getResults());
+    Location loc = op.getLoc();
+    auto newWhileOp =
+        rewriter.create<WhileOp>(loc, op.getResultTypes(), newInits,
+                                 /*beforeBody*/ nullptr, /*afterBody*/ nullptr);
+    Block &newBeforeBlock = newWhileOp.getBefore().front();
+    Block &newAfterBlock = newWhileOp.getAfter().front();
 
+    OpBuilder::InsertionGuard g(rewriter);
     rewriter.setInsertionPoint(yield);
     rewriter.replaceOpWithNewOp<YieldOp>(yield, newYields);
+
+    rewriter.mergeBlocks(&beforeBlock, &newBeforeBlock,
+                         newBeforeBlock.getArguments());
+    rewriter.mergeBlocks(&afterBlock, &newAfterBlock,
+                         newAfterBlock.getArguments());
+
+    rewriter.replaceOp(op, newWhileOp.getResults());
     return success();
   }
 };
@@ -3792,14 +3804,21 @@ struct WhileRemoveDuplicatedResults : public OpRewritePattern<WhileOp> {
 
   LogicalResult matchAndRewrite(WhileOp op,
                                 PatternRewriter &rewriter) const override {
-    Block &beforeBlock = op.getBefore().front();
-    Block &afterBlock = op.getAfter().front();
-
-    auto condOp = cast<ConditionOp>(beforeBlock.getTerminator());
+    ConditionOp condOp = op.getConditionOp();
     ValueRange condOpArgs = condOp.getArgs();
+
+    llvm::SmallPtrSet<Value, 8> argsSet;
+    for (Value arg : condOpArgs)
+      argsSet.insert(arg);
+
+    if (argsSet.size() == condOpArgs.size())
+      return rewriter.notifyMatchFailure(op, "No results to remove");
+
     llvm::SmallDenseMap<Value, unsigned> argsMap;
     SmallVector<Value> newArgs;
-    for (auto arg : condOpArgs) {
+    argsMap.reserve(condOpArgs.size());
+    newArgs.reserve(condOpArgs.size());
+    for (Value arg : condOpArgs) {
       if (!argsMap.count(arg)) {
         auto pos = static_cast<unsigned>(argsMap.size());
         argsMap.insert({arg, pos});
@@ -3807,9 +3826,6 @@ struct WhileRemoveDuplicatedResults : public OpRewritePattern<WhileOp> {
       }
     }
 
-    if (argsMap.size() == condOpArgs.size())
-      return rewriter.notifyMatchFailure(op, "No results to remove");
-
     ValueRange argsRange(newArgs);
 
     Location loc = op.getLoc();
@@ -3834,64 +3850,13 @@ struct WhileRemoveDuplicatedResults : public OpRewritePattern<WhileOp> {
     rewriter.replaceOpWithNewOp<ConditionOp>(condOp, condOp.getCondition(),
                                              argsRange);
 
-    rewriter.mergeBlocks(&beforeBlock, &newBeforeBlock,
-                         newBeforeBlock.getArguments());
-    rewriter.mergeBlocks(&afterBlock, &newAfterBlock, afterArgsMapping);
-    rewriter.replaceOp(op, resultsMapping);
-    return success();
-  }
-};
-
-/// Remove unused init/yield args.
-struct WhileRemoveUnusedArgs : public mlir::OpRewritePattern<WhileOp> {
-  using OpRewritePattern::OpRewritePattern;
-
-  LogicalResult matchAndRewrite(WhileOp op,
-                                PatternRewriter &rewriter) const override {
     Block &beforeBlock = op.getBefore().front();
     Block &afterBlock = op.getAfter().front();
 
-    auto yield = cast<YieldOp>(afterBlock.getTerminator());
-
-    llvm::BitVector argsToRemove;
-    SmallVector<Value> newInits;
-    SmallVector<Value> newYieldArgs;
-
-    bool changed = false;
-    for (auto &&[arg, init, yieldArg] : llvm::zip(
-             beforeBlock.getArguments(), op.getInits(), yield.getResults())) {
-      bool empty = arg.use_empty();
-      argsToRemove.push_back(empty);
-      if (empty) {
-        changed = true;
-        continue;
-      }
-
-      newInits.emplace_back(init);
-      newYieldArgs.emplace_back(yieldArg);
-    }
-
-    if (!changed)
-      return rewriter.notifyMatchFailure(op, "No args to remove");
-
-    beforeBlock.eraseArguments(argsToRemove);
-
-    Location loc = op.getLoc();
-    auto newWhileOp =
-        rewriter.create<WhileOp>(loc, op->getResultTypes(), newInits,
-                                 /*beforeBody*/ nullptr, /*afterBody*/ nullptr);
-    Block &newBeforeBlock = newWhileOp.getBefore().front();
-    Block &newAfterBlock = newWhileOp.getAfter().front();
-
-    OpBuilder::InsertionGuard g(rewriter);
-    rewriter.setInsertionPoint(yield);
-    rewriter.replaceOpWithNewOp<YieldOp>(yield, newYieldArgs);
-
     rewriter.mergeBlocks(&beforeBlock, &newBeforeBlock,
                          newBeforeBlock.getArguments());
-    rewriter.mergeBlocks(&afterBlock, &newAfterBlock,
-                         newAfterBlock.getArguments());
-    rewriter.replaceOp(op, newWhileOp.getResults());
+    rewriter.mergeBlocks(&afterBlock, &newAfterBlock, afterArgsMapping);
+    rewriter.replaceOp(op, resultsMapping);
     return success();
   }
 };

diff  --git a/mlir/test/Dialect/SCF/canonicalize.mlir b/mlir/test/Dialect/SCF/canonicalize.mlir
index 355d343442c46..e55be6127fe24 100644
--- a/mlir/test/Dialect/SCF/canonicalize.mlir
+++ b/mlir/test/Dialect/SCF/canonicalize.mlir
@@ -1019,30 +1019,6 @@ func.func @while_cond_true() -> i1 {
 
 // -----
 
-// CHECK-LABEL: @while_unused_arg
-func.func @while_unused_arg(%x : i32, %y : f64) -> i32 {
-  %0 = scf.while (%arg1 = %x, %arg2 = %y) : (i32, f64) -> (i32) {
-    %condition = "test.condition"(%arg1) : (i32) -> i1
-    scf.condition(%condition) %arg1 : i32
-  } do {
-  ^bb0(%arg1: i32):
-    %next = "test.use"(%arg1) : (i32) -> (i32)
-    scf.yield %next, %y : i32, f64
-  }
-  return %0 : i32
-}
-// CHECK-NEXT:         %[[res:.*]] = scf.while (%[[arg2:.+]] = %{{.*}}) : (i32) -> i32 {
-// CHECK-NEXT:           %[[cmp:.*]] = "test.condition"(%[[arg2]]) : (i32) -> i1
-// CHECK-NEXT:           scf.condition(%[[cmp]]) %[[arg2]] : i32
-// CHECK-NEXT:         } do {
-// CHECK-NEXT:         ^bb0(%[[post:.+]]: i32):
-// CHECK-NEXT:           %[[next:.+]] = "test.use"(%[[post]]) : (i32) -> i32
-// CHECK-NEXT:           scf.yield %[[next]] : i32
-// CHECK-NEXT:         }
-// CHECK-NEXT:         return %[[res]] : i32
-
-// -----
-
 // CHECK-LABEL: @invariant_loop_args_in_same_order
 // CHECK-SAME: (%[[FUNC_ARG0:.*]]: tensor<i32>)
 func.func @invariant_loop_args_in_same_order(%f_arg0: tensor<i32>) -> (tensor<i32>, tensor<i32>, tensor<i32>, tensor<i32>, tensor<i32>) {
@@ -1221,10 +1197,36 @@ func.func @while_duplicated_res() -> (i32, i32) {
 // CHECK:         }
 // CHECK:         return %[[RES]], %[[RES]] : i32, i32
 
+
+// -----
+
+// CHECK-LABEL: @while_unused_arg1
+func.func @while_unused_arg1(%x : i32, %y : f64) -> i32 {
+  %0 = scf.while (%arg1 = %x, %arg2 = %y) : (i32, f64) -> (i32) {
+    %condition = "test.condition"(%arg1) : (i32) -> i1
+    scf.condition(%condition) %arg1 : i32
+  } do {
+  ^bb0(%arg1: i32):
+    %next = "test.use"(%arg1) : (i32) -> (i32)
+    scf.yield %next, %y : i32, f64
+  }
+  return %0 : i32
+}
+// CHECK-NEXT:         %[[res:.*]] = scf.while (%[[arg2:.*]] = %{{.*}}) : (i32) -> i32 {
+// CHECK-NEXT:           %[[cmp:.*]] = "test.condition"(%[[arg2]]) : (i32) -> i1
+// CHECK-NEXT:           scf.condition(%[[cmp]]) %[[arg2]] : i32
+// CHECK-NEXT:         } do {
+// CHECK-NEXT:         ^bb0(%[[post:.*]]: i32):
+// CHECK-NEXT:           %[[next:.*]] = "test.use"(%[[post]]) : (i32) -> i32
+// CHECK-NEXT:           scf.yield %[[next]] : i32
+// CHECK-NEXT:         }
+// CHECK-NEXT:         return %[[res]] : i32
+
+
 // -----
 
-// CHECK-LABEL: @while_unused_arg
-func.func @while_unused_arg(%val0: i32) -> i32 {
+// CHECK-LABEL: @while_unused_arg2
+func.func @while_unused_arg2(%val0: i32) -> i32 {
   %0 = scf.while (%val1 = %val0) : (i32) -> i32 {
     %val = "test.val"() : () -> i32
     %condition = "test.condition"() : () -> i1


        


More information about the Mlir-commits mailing list