[Mlir-commits] [mlir] e78d341 - [mlir][scf] More WhileOp canonicalizations
Ivan Butygin
llvmlistbot at llvm.org
Wed Apr 12 07:40:53 PDT 2023
Author: Ivan Butygin
Date: 2023-04-12T16:39:05+02:00
New Revision: e78d341f936e52df92279cc5bff1bff0625b7d23
URL: https://github.com/llvm/llvm-project/commit/e78d341f936e52df92279cc5bff1bff0625b7d23
DIFF: https://github.com/llvm/llvm-project/commit/e78d341f936e52df92279cc5bff1bff0625b7d23.diff
LOG: [mlir][scf] More WhileOp canonicalizations
Remove duplicated ConditionOp args and remove unused init/yield args.
Differential Revision: https://reviews.llvm.org/D146252
Added:
Modified:
mlir/lib/Dialect/SCF/IR/SCF.cpp
mlir/test/Dialect/SCF/canonicalize.mlir
mlir/test/Dialect/SCF/one-shot-bufferize.mlir
Removed:
################################################################################
diff --git a/mlir/lib/Dialect/SCF/IR/SCF.cpp b/mlir/lib/Dialect/SCF/IR/SCF.cpp
index 3eda0d68f5fbb..b43615c9f1933 100644
--- a/mlir/lib/Dialect/SCF/IR/SCF.cpp
+++ b/mlir/lib/Dialect/SCF/IR/SCF.cpp
@@ -3785,13 +3785,129 @@ struct WhileUnusedArg : public OpRewritePattern<WhileOp> {
return success();
}
};
+
+/// Remove duplicated ConditionOp args.
+///
+/// Matches an `scf.while` whose `scf.condition` forwards the same SSA value
+/// more than once (e.g. `scf.condition(%c) %v, %v`). The rewritten loop
+/// forwards each distinct value only once; every after-block argument and
+/// every op result that corresponded to a duplicate is remapped onto the
+/// single surviving argument/result, so existing users see the same values.
+struct WhileRemoveDuplicatedResults : public OpRewritePattern<WhileOp> {
+  using OpRewritePattern::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(WhileOp op,
+                                PatternRewriter &rewriter) const override {
+    Block &beforeBlock = op.getBefore().front();
+    Block &afterBlock = op.getAfter().front();
+
+    auto condOp = cast<ConditionOp>(beforeBlock.getTerminator());
+    ValueRange condOpArgs = condOp.getArgs();
+
+    // Map each distinct forwarded value to its position in the deduplicated
+    // argument list. `try_emplace` performs the membership test and the
+    // insertion with a single hash lookup.
+    llvm::SmallDenseMap<Value, unsigned> argsMap;
+    SmallVector<Value> newArgs;
+    for (Value arg : condOpArgs) {
+      if (argsMap.try_emplace(arg, static_cast<unsigned>(newArgs.size()))
+              .second)
+        newArgs.push_back(arg);
+    }
+
+    if (newArgs.size() == condOpArgs.size())
+      return rewriter.notifyMatchFailure(op, "No results to remove");
+
+    ValueRange argsRange(newArgs);
+    // The new regions are filled by merging the old blocks in below, so the
+    // region builder callbacks intentionally create nothing.
+    auto emptyBuilder = [](OpBuilder &, Location, ValueRange) {
+      // Nothing
+    };
+
+    Location loc = op.getLoc();
+    auto newWhileOp = rewriter.create<scf::WhileOp>(
+        loc, argsRange.getTypes(), op.getInits(), emptyBuilder, emptyBuilder);
+    Block &newBeforeBlock = newWhileOp.getBefore().front();
+    Block &newAfterBlock = newWhileOp.getAfter().front();
+
+    // For every original condition operand (duplicates included), record the
+    // new after-block argument and new op result it collapses onto.
+    SmallVector<Value> afterArgsMapping;
+    SmallVector<Value> resultsMapping;
+    afterArgsMapping.reserve(condOpArgs.size());
+    resultsMapping.reserve(condOpArgs.size());
+    for (Value arg : condOpArgs) {
+      auto it = argsMap.find(arg);
+      assert(it != argsMap.end() && "every condition operand must be mapped");
+      unsigned pos = it->second;
+      afterArgsMapping.push_back(newAfterBlock.getArgument(pos));
+      resultsMapping.push_back(newWhileOp->getResult(pos));
+    }
+
+    OpBuilder::InsertionGuard g(rewriter);
+    rewriter.setInsertionPoint(condOp);
+    // This must stay after the mapping loop: `condOpArgs` is a view of
+    // `condOp`'s operands and is no longer valid once `condOp` is replaced.
+    rewriter.replaceOpWithNewOp<ConditionOp>(condOp, condOp.getCondition(),
+                                             argsRange);
+
+    rewriter.mergeBlocks(&beforeBlock, &newBeforeBlock,
+                         newBeforeBlock.getArguments());
+    rewriter.mergeBlocks(&afterBlock, &newAfterBlock, afterArgsMapping);
+    rewriter.replaceOp(op, resultsMapping);
+    return success();
+  }
+};
+
+/// Remove unused init/yield args.
+///
+/// Matches an `scf.while` with before-block arguments that have no uses. For
+/// each such argument, the matching init operand and the matching
+/// `scf.yield` operand are dropped and the argument itself is erased. The
+/// op's result types are kept unchanged (`op->getResultTypes()` is reused for
+/// the new op).
+struct WhileRemoveUnusedArgs : public OpRewritePattern<WhileOp> {
+  using OpRewritePattern::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(WhileOp op,
+                                PatternRewriter &rewriter) const override {
+    Block &beforeBlock = op.getBefore().front();
+    Block &afterBlock = op.getAfter().front();
+
+    auto yield = cast<YieldOp>(afterBlock.getTerminator());
+
+    // Bit i of `argsToRemove` is set iff before-block argument i has no uses;
+    // the init/yield operands of the surviving arguments are kept in order.
+    unsigned numArgs = beforeBlock.getNumArguments();
+    llvm::BitVector argsToRemove;
+    SmallVector<Value> newInits;
+    SmallVector<Value> newYieldArgs;
+    argsToRemove.reserve(numArgs);
+    newInits.reserve(numArgs);
+    newYieldArgs.reserve(numArgs);
+
+    for (auto &&[arg, init, yieldArg] : llvm::zip(
+             beforeBlock.getArguments(), op.getInits(), yield.getResults())) {
+      bool unused = arg.use_empty();
+      argsToRemove.push_back(unused);
+      if (unused)
+        continue;
+
+      newInits.push_back(init);
+      newYieldArgs.push_back(yieldArg);
+    }
+
+    if (argsToRemove.none())
+      return rewriter.notifyMatchFailure(op, "No args to remove");
+
+    // Erase the dead arguments from the old before block first so its
+    // remaining arguments line up one-to-one with the new op's before-block
+    // arguments when the block is merged below. This edits `op`'s region
+    // directly; `op` itself is replaced at the end of this rewrite.
+    beforeBlock.eraseArguments(argsToRemove);
+
+    // The new regions are filled by merging the old blocks in below, so the
+    // region builder callbacks intentionally create nothing.
+    auto emptyBuilder = [](OpBuilder &, Location, ValueRange) {
+      // Nothing
+    };
+
+    Location loc = op.getLoc();
+    auto newWhileOp = rewriter.create<WhileOp>(
+        loc, op->getResultTypes(), newInits, emptyBuilder, emptyBuilder);
+    Block &newBeforeBlock = newWhileOp.getBefore().front();
+    Block &newAfterBlock = newWhileOp.getAfter().front();
+
+    OpBuilder::InsertionGuard g(rewriter);
+    rewriter.setInsertionPoint(yield);
+    rewriter.replaceOpWithNewOp<YieldOp>(yield, newYieldArgs);
+
+    rewriter.mergeBlocks(&beforeBlock, &newBeforeBlock,
+                         newBeforeBlock.getArguments());
+    rewriter.mergeBlocks(&afterBlock, &newAfterBlock,
+                         newAfterBlock.getArguments());
+    rewriter.replaceOp(op, newWhileOp.getResults());
+    return success();
+  }
+};
} // namespace
+/// Registers the scf.while canonicalization patterns, including the
+/// WhileRemoveDuplicatedResults and WhileRemoveUnusedArgs patterns.
void WhileOp::getCanonicalizationPatterns(RewritePatternSet &results,
MLIRContext *context) {
results.add<RemoveLoopInvariantArgsFromBeforeBlock,
RemoveLoopInvariantValueYielded, WhileConditionTruth,
- WhileCmpCond, WhileUnusedResult>(context);
+ WhileCmpCond, WhileUnusedResult, WhileRemoveDuplicatedResults,
+ WhileRemoveUnusedArgs>(context);
}
//===----------------------------------------------------------------------===//
diff --git a/mlir/test/Dialect/SCF/canonicalize.mlir b/mlir/test/Dialect/SCF/canonicalize.mlir
index ec6e35a200075..355d343442c46 100644
--- a/mlir/test/Dialect/SCF/canonicalize.mlir
+++ b/mlir/test/Dialect/SCF/canonicalize.mlir
@@ -1195,6 +1195,60 @@ func.func @while_cmp_rhs(%arg0 : i32) {
// CHECK-NEXT: scf.yield
// CHECK-NEXT: }
+// -----
+
+// CHECK-LABEL: @while_duplicated_res
+// The condition forwards the same value (%val) twice. After canonicalization
+// the loop is expected to forward it once, the after block to receive a
+// single argument used in both operand positions of "test.use", and both
+// function results to come from the single loop result.
+func.func @while_duplicated_res() -> (i32, i32) {
+ %0:2 = scf.while () : () -> (i32, i32) {
+ %val = "test.val"() : () -> i32
+ %condition = "test.condition"() : () -> i1
+ scf.condition(%condition) %val, %val : i32, i32
+ } do {
+ ^bb0(%val2: i32, %val3: i32):
+ "test.use"(%val2, %val3) : (i32, i32) -> ()
+ scf.yield
+ }
+ return %0#0, %0#1: i32, i32
+}
+// CHECK: %[[RES:.*]] = scf.while : () -> i32 {
+// CHECK: %[[VAL:.*]] = "test.val"() : () -> i32
+// CHECK: %[[COND:.*]] = "test.condition"() : () -> i1
+// CHECK: scf.condition(%[[COND]]) %[[VAL]] : i32
+// CHECK: } do {
+// CHECK: ^bb0(%[[ARG:.*]]: i32):
+// CHECK: "test.use"(%[[ARG]], %[[ARG]]) : (i32, i32) -> ()
+// CHECK: scf.yield
+// CHECK: }
+// CHECK: return %[[RES]], %[[RES]] : i32, i32
+
+// -----
+
+// CHECK-LABEL: @while_unused_arg
+// The before-region argument %val1 has no uses. The expected output drops it
+// together with its init operand (%val0) and the matching scf.yield operand,
+// leaving an scf.while with no init operands.
+func.func @while_unused_arg(%val0: i32) -> i32 {
+ %0 = scf.while (%val1 = %val0) : (i32) -> i32 {
+ %val = "test.val"() : () -> i32
+ %condition = "test.condition"() : () -> i1
+ scf.condition(%condition) %val: i32
+ } do {
+ ^bb0(%val2: i32):
+ "test.use"(%val2) : (i32) -> ()
+ %val1 = "test.val1"() : () -> i32
+ scf.yield %val1 : i32
+ }
+ return %0 : i32
+}
+// CHECK: %[[RES:.*]] = scf.while : () -> i32 {
+// CHECK: %[[VAL:.*]] = "test.val"() : () -> i32
+// CHECK: %[[COND:.*]] = "test.condition"() : () -> i1
+// CHECK: scf.condition(%[[COND]]) %[[VAL]] : i32
+// CHECK: } do {
+// CHECK: ^bb0(%[[ARG:.*]]: i32):
+// CHECK: "test.use"(%[[ARG]]) : (i32) -> ()
+// CHECK: scf.yield
+// CHECK: }
+// CHECK: return %[[RES]] : i32
+
+
// -----
// CHECK-LABEL: @combineIfs
diff --git a/mlir/test/Dialect/SCF/one-shot-bufferize.mlir b/mlir/test/Dialect/SCF/one-shot-bufferize.mlir
index 6a7bd40e3bbee..e23e86184c311 100644
--- a/mlir/test/Dialect/SCF/one-shot-bufferize.mlir
+++ b/mlir/test/Dialect/SCF/one-shot-bufferize.mlir
@@ -518,6 +518,8 @@ func.func @scf_while_iter_arg_result_mismatch(%arg0: tensor<5xi1>,
%arg2: index) {
scf.while (%arg3 = %arg1) : (tensor<5xi1>) -> () {
%0 = tensor.extract %arg0[%arg2] : tensor<5xi1>
+ %1 = tensor.extract %arg3[%arg2] : tensor<5xi1>
+ "dummy.use"(%1) : (i1) -> ()
scf.condition(%0)
} do {
%0 = "dummy.some_op"() : () -> index
More information about the Mlir-commits
mailing list