[Mlir-commits] [mlir] d23fa4f - [MLIR][SCF] Remove unused arguments to whileop
William S. Moses
llvmlistbot at llvm.org
Tue Jan 11 17:18:13 PST 2022
Author: William S. Moses
Date: 2022-01-11T20:18:08-05:00
New Revision: d23fa4f2f1319039b7d939a4bf68c493ba26b07a
URL: https://github.com/llvm/llvm-project/commit/d23fa4f2f1319039b7d939a4bf68c493ba26b07a
DIFF: https://github.com/llvm/llvm-project/commit/d23fa4f2f1319039b7d939a4bf68c493ba26b07a.diff
LOG: [MLIR][SCF] Remove unused arguments to whileop
Canonicalize away unused arguments to the before region of a whileOp
Reviewed By: mehdi_amini
Differential Revision: https://reviews.llvm.org/D117059
Added:
Modified:
mlir/include/mlir/Dialect/SCF/SCFOps.td
mlir/lib/Dialect/SCF/SCF.cpp
mlir/test/Dialect/SCF/canonicalize.mlir
Removed:
################################################################################
diff --git a/mlir/include/mlir/Dialect/SCF/SCFOps.td b/mlir/include/mlir/Dialect/SCF/SCFOps.td
index c8e51692252d6..e060dac6ce211 100644
--- a/mlir/include/mlir/Dialect/SCF/SCFOps.td
+++ b/mlir/include/mlir/Dialect/SCF/SCFOps.td
@@ -686,6 +686,8 @@ def WhileOp : SCF_Op<"while",
let extraClassDeclaration = [{
OperandRange getSuccessorEntryOperands(unsigned index);
ConditionOp getConditionOp();
+ /// Returns the terminator of the entry block of the "after" region, cast
+ /// to YieldOp.
+ YieldOp getYieldOp();
+ /// Returns the arguments of the entry block of the "before" region.
+ Block::BlockArgListType getBeforeArguments();
Block::BlockArgListType getAfterArguments();
}];
diff --git a/mlir/lib/Dialect/SCF/SCF.cpp b/mlir/lib/Dialect/SCF/SCF.cpp
index 6b9fe80bbccdd..5fe73a8d57c51 100644
--- a/mlir/lib/Dialect/SCF/SCF.cpp
+++ b/mlir/lib/Dialect/SCF/SCF.cpp
@@ -2171,6 +2171,14 @@ ConditionOp WhileOp::getConditionOp() {
return cast<ConditionOp>(getBefore().front().getTerminator());
}
+/// Returns the terminator of the entry block of the "after" region. Uses
+/// `cast`, so the terminator is expected to be a YieldOp.
+YieldOp WhileOp::getYieldOp() {
+ return cast<YieldOp>(getAfter().front().getTerminator());
+}
+
+/// Returns the arguments of the entry block of the "before" region.
+Block::BlockArgListType WhileOp::getBeforeArguments() {
+ return getBefore().front().getArguments();
+}
+
Block::BlockArgListType WhileOp::getAfterArguments() {
return getAfter().front().getArguments();
}
@@ -2508,11 +2516,60 @@ struct WhileCmpCond : public OpRewritePattern<scf::WhileOp> {
return success(changed);
}
};
+
+/// Removes arguments of the "before" region's entry block that have no uses,
+/// together with the matching init operand of the `scf.while` and the
+/// matching operand of the "after" region's `scf.yield`. The loop's result
+/// types are left unchanged.
+struct WhileUnusedArg : public OpRewritePattern<WhileOp> {
+  using OpRewritePattern<WhileOp>::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(WhileOp op,
+                                PatternRewriter &rewriter) const override {
+    // Cheap early exit: only fire when at least one before-argument is dead.
+    if (!llvm::any_of(op.getBeforeArguments(),
+                      [](Value arg) { return arg.use_empty(); }))
+      return failure();
+
+    YieldOp yield = op.getYieldOp();
+
+    // Keep the init/yield pair of every used before-argument and record the
+    // positions of the dead ones. `argsToErase` is non-empty here because of
+    // the early exit above.
+    SmallVector<Value> newYields;
+    SmallVector<Value> newInits;
+    SmallVector<unsigned> argsToErase;
+    for (const auto &it : llvm::enumerate(llvm::zip(
+             op.getBeforeArguments(), yield.getOperands(), op.getInits()))) {
+      Value beforeArg = std::get<0>(it.value());
+      Value yieldValue = std::get<1>(it.value());
+      Value initValue = std::get<2>(it.value());
+      if (beforeArg.use_empty()) {
+        argsToErase.push_back(it.index());
+      } else {
+        newYields.push_back(yieldValue);
+        newInits.push_back(initValue);
+      }
+    }
+
+    // The erased arguments have no uses, so dropping them is safe. This
+    // mutates `op` in place, so bracket it with a root update.
+    rewriter.startRootUpdate(op);
+    op.getBefore().front().eraseArguments(argsToErase);
+    rewriter.finalizeRootUpdate(op);
+
+    // Create the replacement loop with only the surviving inits and move both
+    // regions into it through the rewriter (not Region::takeBody) so that the
+    // region motion is visible to the driver applying this pattern.
+    WhileOp replacement =
+        rewriter.create<WhileOp>(op.getLoc(), op.getResultTypes(), newInits);
+    rewriter.inlineRegionBefore(op.getBefore(), replacement.getBefore(),
+                                replacement.getBefore().end());
+    rewriter.inlineRegionBefore(op.getAfter(), replacement.getAfter(),
+                                replacement.getAfter().end());
+
+    // Shrink the yield so it only forwards values that still feed a
+    // before-argument; the guard restores the rewriter's insertion point.
+    {
+      OpBuilder::InsertionGuard guard(rewriter);
+      rewriter.setInsertionPoint(yield);
+      rewriter.replaceOpWithNewOp<YieldOp>(yield, newYields);
+    }
+
+    rewriter.replaceOp(op, replacement.getResults());
+    return success();
+  }
+};
} // namespace
void WhileOp::getCanonicalizationPatterns(OwningRewritePatternList &results,
MLIRContext *context) {
- results.insert<WhileConditionTruth, WhileUnusedResult, WhileCmpCond>(context);
+ // WhileUnusedArg drops "before"-region block arguments that have no uses.
+ results.insert<WhileConditionTruth, WhileUnusedResult, WhileCmpCond,
+ WhileUnusedArg>(context);
}
//===----------------------------------------------------------------------===//
diff --git a/mlir/test/Dialect/SCF/canonicalize.mlir b/mlir/test/Dialect/SCF/canonicalize.mlir
index 2b3a31e833a3d..65c20542fb2f6 100644
--- a/mlir/test/Dialect/SCF/canonicalize.mlir
+++ b/mlir/test/Dialect/SCF/canonicalize.mlir
@@ -846,6 +846,30 @@ func @while_cond_true() -> i1 {
// -----
+// The "before" argument %arg2 (initialized from %y) is never used, so the
+// loop should be rewritten to carry only %arg1: one init operand, one
+// before-argument, and a single-operand scf.yield in the "after" region.
+// CHECK-LABEL: @while_unused_arg
+func @while_unused_arg(%x : i32, %y : f64) -> i32 {
+ %0 = scf.while (%arg1 = %x, %arg2 = %y) : (i32, f64) -> (i32) {
+ %condition = "test.condition"(%arg1) : (i32) -> i1
+ scf.condition(%condition) %arg1 : i32
+ } do {
+ ^bb0(%arg1: i32):
+ %next = "test.use"(%arg1) : (i32) -> (i32)
+ scf.yield %next, %y : i32, f64
+ }
+ return %0 : i32
+}
+// CHECK-NEXT: %[[res:.*]] = scf.while (%[[arg2:.+]] = %{{.*}}) : (i32) -> i32 {
+// CHECK-NEXT: %[[cmp:.*]] = "test.condition"(%[[arg2]]) : (i32) -> i1
+// CHECK-NEXT: scf.condition(%[[cmp]]) %[[arg2]] : i32
+// CHECK-NEXT: } do {
+// CHECK-NEXT: ^bb0(%[[post:.+]]: i32): // no predecessors
+// CHECK-NEXT: %[[next:.+]] = "test.use"(%[[post]]) : (i32) -> i32
+// CHECK-NEXT: scf.yield %[[next]] : i32
+// CHECK-NEXT: }
+// CHECK-NEXT: return %[[res]] : i32
+
+// -----
+
// CHECK-LABEL: @while_unused_result
func @while_unused_result() -> i32 {
%0:2 = scf.while () : () -> (i32, i64) {
More information about the Mlir-commits mailing list