[Mlir-commits] [mlir] c97b7bc - [mlir][scf] WhileOp patterns cleanup
Ivan Butygin
llvmlistbot at llvm.org
Fri Apr 14 05:19:51 PDT 2023
Author: Ivan Butygin
Date: 2023-04-14T13:46:11+02:00
New Revision: c97b7bcf3e20b8423692e38dd0a1f249e6215bd8
URL: https://github.com/llvm/llvm-project/commit/c97b7bcf3e20b8423692e38dd0a1f249e6215bd8
DIFF: https://github.com/llvm/llvm-project/commit/c97b7bcf3e20b8423692e38dd0a1f249e6215bd8.diff
LOG: [mlir][scf] WhileOp patterns cleanup
Address review comments from https://reviews.llvm.org/D146252:
merge the `WhileRemoveUnusedArgs` pattern with the (unused) `WhileUnusedArg` pattern,
use `getConditionOp`, use a `SmallPtrSet` with an early-exit check, and move the tests.
Differential Revision: https://reviews.llvm.org/D148256
Added:
Modified:
mlir/lib/Dialect/SCF/IR/SCF.cpp
mlir/test/Dialect/SCF/canonicalize.mlir
Removed:
################################################################################
diff --git a/mlir/lib/Dialect/SCF/IR/SCF.cpp b/mlir/lib/Dialect/SCF/IR/SCF.cpp
index ed6d9f25e296c..06d4addae84f0 100644
--- a/mlir/lib/Dialect/SCF/IR/SCF.cpp
+++ b/mlir/lib/Dialect/SCF/IR/SCF.cpp
@@ -22,6 +22,7 @@
#include "mlir/Support/MathExtras.h"
#include "mlir/Transforms/InliningUtils.h"
#include "llvm/ADT/MapVector.h"
+#include "llvm/ADT/SmallPtrSet.h"
#include "llvm/ADT/TypeSwitch.h"
using namespace mlir;
@@ -3738,7 +3739,8 @@ struct WhileCmpCond : public OpRewritePattern<scf::WhileOp> {
}
};
-struct WhileUnusedArg : public OpRewritePattern<WhileOp> {
+/// Remove unused init/yield args.
+struct WhileRemoveUnusedArgs : public OpRewritePattern<WhileOp> {
using OpRewritePattern<WhileOp>::OpRewritePattern;
LogicalResult matchAndRewrite(WhileOp op,
@@ -3746,42 +3748,52 @@ struct WhileUnusedArg : public OpRewritePattern<WhileOp> {
if (!llvm::any_of(op.getBeforeArguments(),
[](Value arg) { return arg.use_empty(); }))
- return failure();
+ return rewriter.notifyMatchFailure(op, "No args to remove");
YieldOp yield = op.getYieldOp();
// Collect results mapping, new terminator args and new result types.
SmallVector<Value> newYields;
SmallVector<Value> newInits;
- llvm::BitVector argsToErase(op.getBeforeArguments().size());
- for (const auto &it : llvm::enumerate(llvm::zip(
- op.getBeforeArguments(), yield.getOperands(), op.getInits()))) {
- Value beforeArg = std::get<0>(it.value());
- Value yieldValue = std::get<1>(it.value());
- Value initValue = std::get<2>(it.value());
+ llvm::BitVector argsToErase;
+
+ size_t argsCount = op.getBeforeArguments().size();
+ newYields.reserve(argsCount);
+ newInits.reserve(argsCount);
+ argsToErase.reserve(argsCount);
+ for (auto &&[beforeArg, yieldValue, initValue] : llvm::zip(
+ op.getBeforeArguments(), yield.getOperands(), op.getInits())) {
if (beforeArg.use_empty()) {
- argsToErase.set(it.index());
+ argsToErase.push_back(true);
} else {
+ argsToErase.push_back(false);
newYields.emplace_back(yieldValue);
newInits.emplace_back(initValue);
}
}
- if (argsToErase.none())
- return failure();
+ Block &beforeBlock = op.getBefore().front();
+ Block &afterBlock = op.getAfter().front();
- rewriter.startRootUpdate(op);
- op.getBefore().front().eraseArguments(argsToErase);
- rewriter.finalizeRootUpdate(op);
+ beforeBlock.eraseArguments(argsToErase);
- WhileOp replacement =
- rewriter.create<WhileOp>(op.getLoc(), op.getResultTypes(), newInits);
- replacement.getBefore().takeBody(op.getBefore());
- replacement.getAfter().takeBody(op.getAfter());
- rewriter.replaceOp(op, replacement.getResults());
+ Location loc = op.getLoc();
+ auto newWhileOp =
+ rewriter.create<WhileOp>(loc, op.getResultTypes(), newInits,
+ /*beforeBody*/ nullptr, /*afterBody*/ nullptr);
+ Block &newBeforeBlock = newWhileOp.getBefore().front();
+ Block &newAfterBlock = newWhileOp.getAfter().front();
+ OpBuilder::InsertionGuard g(rewriter);
rewriter.setInsertionPoint(yield);
rewriter.replaceOpWithNewOp<YieldOp>(yield, newYields);
+
+ rewriter.mergeBlocks(&beforeBlock, &newBeforeBlock,
+ newBeforeBlock.getArguments());
+ rewriter.mergeBlocks(&afterBlock, &newAfterBlock,
+ newAfterBlock.getArguments());
+
+ rewriter.replaceOp(op, newWhileOp.getResults());
return success();
}
};
@@ -3792,14 +3804,21 @@ struct WhileRemoveDuplicatedResults : public OpRewritePattern<WhileOp> {
LogicalResult matchAndRewrite(WhileOp op,
PatternRewriter &rewriter) const override {
- Block &beforeBlock = op.getBefore().front();
- Block &afterBlock = op.getAfter().front();
-
- auto condOp = cast<ConditionOp>(beforeBlock.getTerminator());
+ ConditionOp condOp = op.getConditionOp();
ValueRange condOpArgs = condOp.getArgs();
+
+ llvm::SmallPtrSet<Value, 8> argsSet;
+ for (Value arg : condOpArgs)
+ argsSet.insert(arg);
+
+ if (argsSet.size() == condOpArgs.size())
+ return rewriter.notifyMatchFailure(op, "No results to remove");
+
llvm::SmallDenseMap<Value, unsigned> argsMap;
SmallVector<Value> newArgs;
- for (auto arg : condOpArgs) {
+ argsMap.reserve(condOpArgs.size());
+ newArgs.reserve(condOpArgs.size());
+ for (Value arg : condOpArgs) {
if (!argsMap.count(arg)) {
auto pos = static_cast<unsigned>(argsMap.size());
argsMap.insert({arg, pos});
@@ -3807,9 +3826,6 @@ struct WhileRemoveDuplicatedResults : public OpRewritePattern<WhileOp> {
}
}
- if (argsMap.size() == condOpArgs.size())
- return rewriter.notifyMatchFailure(op, "No results to remove");
-
ValueRange argsRange(newArgs);
Location loc = op.getLoc();
@@ -3834,64 +3850,13 @@ struct WhileRemoveDuplicatedResults : public OpRewritePattern<WhileOp> {
rewriter.replaceOpWithNewOp<ConditionOp>(condOp, condOp.getCondition(),
argsRange);
- rewriter.mergeBlocks(&beforeBlock, &newBeforeBlock,
- newBeforeBlock.getArguments());
- rewriter.mergeBlocks(&afterBlock, &newAfterBlock, afterArgsMapping);
- rewriter.replaceOp(op, resultsMapping);
- return success();
- }
-};
-
-/// Remove unused init/yield args.
-struct WhileRemoveUnusedArgs : public mlir::OpRewritePattern<WhileOp> {
- using OpRewritePattern::OpRewritePattern;
-
- LogicalResult matchAndRewrite(WhileOp op,
- PatternRewriter &rewriter) const override {
Block &beforeBlock = op.getBefore().front();
Block &afterBlock = op.getAfter().front();
- auto yield = cast<YieldOp>(afterBlock.getTerminator());
-
- llvm::BitVector argsToRemove;
- SmallVector<Value> newInits;
- SmallVector<Value> newYieldArgs;
-
- bool changed = false;
- for (auto &&[arg, init, yieldArg] : llvm::zip(
- beforeBlock.getArguments(), op.getInits(), yield.getResults())) {
- bool empty = arg.use_empty();
- argsToRemove.push_back(empty);
- if (empty) {
- changed = true;
- continue;
- }
-
- newInits.emplace_back(init);
- newYieldArgs.emplace_back(yieldArg);
- }
-
- if (!changed)
- return rewriter.notifyMatchFailure(op, "No args to remove");
-
- beforeBlock.eraseArguments(argsToRemove);
-
- Location loc = op.getLoc();
- auto newWhileOp =
- rewriter.create<WhileOp>(loc, op->getResultTypes(), newInits,
- /*beforeBody*/ nullptr, /*afterBody*/ nullptr);
- Block &newBeforeBlock = newWhileOp.getBefore().front();
- Block &newAfterBlock = newWhileOp.getAfter().front();
-
- OpBuilder::InsertionGuard g(rewriter);
- rewriter.setInsertionPoint(yield);
- rewriter.replaceOpWithNewOp<YieldOp>(yield, newYieldArgs);
-
rewriter.mergeBlocks(&beforeBlock, &newBeforeBlock,
newBeforeBlock.getArguments());
- rewriter.mergeBlocks(&afterBlock, &newAfterBlock,
- newAfterBlock.getArguments());
- rewriter.replaceOp(op, newWhileOp.getResults());
+ rewriter.mergeBlocks(&afterBlock, &newAfterBlock, afterArgsMapping);
+ rewriter.replaceOp(op, resultsMapping);
return success();
}
};
diff --git a/mlir/test/Dialect/SCF/canonicalize.mlir b/mlir/test/Dialect/SCF/canonicalize.mlir
index 355d343442c46..e55be6127fe24 100644
--- a/mlir/test/Dialect/SCF/canonicalize.mlir
+++ b/mlir/test/Dialect/SCF/canonicalize.mlir
@@ -1019,30 +1019,6 @@ func.func @while_cond_true() -> i1 {
// -----
-// CHECK-LABEL: @while_unused_arg
-func.func @while_unused_arg(%x : i32, %y : f64) -> i32 {
- %0 = scf.while (%arg1 = %x, %arg2 = %y) : (i32, f64) -> (i32) {
- %condition = "test.condition"(%arg1) : (i32) -> i1
- scf.condition(%condition) %arg1 : i32
- } do {
- ^bb0(%arg1: i32):
- %next = "test.use"(%arg1) : (i32) -> (i32)
- scf.yield %next, %y : i32, f64
- }
- return %0 : i32
-}
-// CHECK-NEXT: %[[res:.*]] = scf.while (%[[arg2:.+]] = %{{.*}}) : (i32) -> i32 {
-// CHECK-NEXT: %[[cmp:.*]] = "test.condition"(%[[arg2]]) : (i32) -> i1
-// CHECK-NEXT: scf.condition(%[[cmp]]) %[[arg2]] : i32
-// CHECK-NEXT: } do {
-// CHECK-NEXT: ^bb0(%[[post:.+]]: i32):
-// CHECK-NEXT: %[[next:.+]] = "test.use"(%[[post]]) : (i32) -> i32
-// CHECK-NEXT: scf.yield %[[next]] : i32
-// CHECK-NEXT: }
-// CHECK-NEXT: return %[[res]] : i32
-
-// -----
-
// CHECK-LABEL: @invariant_loop_args_in_same_order
// CHECK-SAME: (%[[FUNC_ARG0:.*]]: tensor<i32>)
func.func @invariant_loop_args_in_same_order(%f_arg0: tensor<i32>) -> (tensor<i32>, tensor<i32>, tensor<i32>, tensor<i32>, tensor<i32>) {
@@ -1221,10 +1197,36 @@ func.func @while_duplicated_res() -> (i32, i32) {
// CHECK: }
// CHECK: return %[[RES]], %[[RES]] : i32, i32
+
+// -----
+
+// CHECK-LABEL: @while_unused_arg1
+func.func @while_unused_arg1(%x : i32, %y : f64) -> i32 {
+ %0 = scf.while (%arg1 = %x, %arg2 = %y) : (i32, f64) -> (i32) {
+ %condition = "test.condition"(%arg1) : (i32) -> i1
+ scf.condition(%condition) %arg1 : i32
+ } do {
+ ^bb0(%arg1: i32):
+ %next = "test.use"(%arg1) : (i32) -> (i32)
+ scf.yield %next, %y : i32, f64
+ }
+ return %0 : i32
+}
+// CHECK-NEXT: %[[res:.*]] = scf.while (%[[arg2:.*]] = %{{.*}}) : (i32) -> i32 {
+// CHECK-NEXT: %[[cmp:.*]] = "test.condition"(%[[arg2]]) : (i32) -> i1
+// CHECK-NEXT: scf.condition(%[[cmp]]) %[[arg2]] : i32
+// CHECK-NEXT: } do {
+// CHECK-NEXT: ^bb0(%[[post:.*]]: i32):
+// CHECK-NEXT: %[[next:.*]] = "test.use"(%[[post]]) : (i32) -> i32
+// CHECK-NEXT: scf.yield %[[next]] : i32
+// CHECK-NEXT: }
+// CHECK-NEXT: return %[[res]] : i32
+
+
// -----
-// CHECK-LABEL: @while_unused_arg
-func.func @while_unused_arg(%val0: i32) -> i32 {
+// CHECK-LABEL: @while_unused_arg2
+func.func @while_unused_arg2(%val0: i32) -> i32 {
%0 = scf.while (%val1 = %val0) : (i32) -> i32 {
%val = "test.val"() : () -> i32
%condition = "test.condition"() : () -> i1
More information about the Mlir-commits mailing list