[Mlir-commits] [mlir] 7f5d9bf - [mlir][scf] Canonicalize scf.while with unused results
llvmlistbot at llvm.org
llvmlistbot at llvm.org
Wed Nov 24 00:40:17 PST 2021
Author: Butygin
Date: 2021-11-24T11:11:22+03:00
New Revision: 7f5d9bf13a7d872dc173d09598743c006f59d6b8
URL: https://github.com/llvm/llvm-project/commit/7f5d9bf13a7d872dc173d09598743c006f59d6b8
DIFF: https://github.com/llvm/llvm-project/commit/7f5d9bf13a7d872dc173d09598743c006f59d6b8.diff
LOG: [mlir][scf] Canonicalize scf.while with unused results
Differential Revision: https://reviews.llvm.org/D114291
Added:
Modified:
mlir/lib/Dialect/SCF/SCF.cpp
mlir/test/Dialect/SCF/canonicalize.mlir
Removed:
################################################################################
diff --git a/mlir/lib/Dialect/SCF/SCF.cpp b/mlir/lib/Dialect/SCF/SCF.cpp
index 57ad2f412852f..358d28bcc8367 100644
--- a/mlir/lib/Dialect/SCF/SCF.cpp
+++ b/mlir/lib/Dialect/SCF/SCF.cpp
@@ -2255,11 +2255,102 @@ struct WhileConditionTruth : public OpRewritePattern<WhileOp> {
return success(replaced);
}
};
+
+/// Remove scf.while results that are unused and whose corresponding 'after'
+/// block arguments are unused as well. Result i, 'after' block argument i and
+/// scf.condition operand i are treated as one position; when a position is
+/// dropped, scf.condition simply stops forwarding that operand. The 'before'
+/// region is moved into the new loop as-is (only its terminator is rebuilt),
+/// so ops that compute a dropped value are kept (%v2 below is still computed).
+///
+///   %0:2 = scf.while () : () -> (i32, i64) {
+///     %condition = "test.condition"() : () -> i1
+///     %v1 = "test.get_some_value"() : () -> i32
+///     %v2 = "test.get_some_value"() : () -> i64
+///     scf.condition(%condition) %v1, %v2 : i32, i64
+///   } do {
+///   ^bb0(%arg0: i32, %arg1: i64):
+///     "test.use"(%arg0) : (i32) -> ()
+///     scf.yield
+///   }
+///   return %0#0 : i32
+///
+/// becomes
+///
+///   %0 = scf.while () : () -> (i32) {
+///     %condition = "test.condition"() : () -> i1
+///     %v1 = "test.get_some_value"() : () -> i32
+///     %v2 = "test.get_some_value"() : () -> i64
+///     scf.condition(%condition) %v1 : i32
+///   } do {
+///   ^bb0(%arg0: i32):
+///     "test.use"(%arg0) : (i32) -> ()
+///     scf.yield
+///   }
+///   return %0 : i32
+struct WhileUnusedResult : public OpRewritePattern<WhileOp> {
+ using OpRewritePattern<WhileOp>::OpRewritePattern;
+
+ /// Returns failure() without modifying the IR when every position is still
+ /// used, either through the op result or through the 'after' block argument.
+ LogicalResult matchAndRewrite(WhileOp op,
+ PatternRewriter &rewriter) const override {
+ auto term = op.getConditionOp();
+ auto afterArgs = op.getAfterArguments();
+ auto termArgs = term.args();
+
+ // Collect results mapping, new terminator args and new result types.
+ SmallVector<unsigned> newResultsIndices;
+ SmallVector<Type> newResultTypes;
+ SmallVector<Value> newTermArgs;
+ bool needUpdate = false;
+ for (auto it :
+ llvm::enumerate(llvm::zip(op.getResults(), afterArgs, termArgs))) {
+ auto i = static_cast<unsigned>(it.index());
+ Value result = std::get<0>(it.value());
+ Value afterArg = std::get<1>(it.value());
+ Value termArg = std::get<2>(it.value());
+ // A position can be dropped only if neither the loop result nor the
+ // matching 'after' block argument has any use.
+ if (result.use_empty() && afterArg.use_empty()) {
+ needUpdate = true;
+ } else {
+ newResultsIndices.emplace_back(i);
+ newTermArgs.emplace_back(termArg);
+ newResultTypes.emplace_back(result.getType());
+ }
+ }
+
+ if (!needUpdate)
+ return failure();
+
+ // Rebuild the scf.condition terminator so it forwards only the kept
+ // operands; the guard restores the previous insertion point afterwards.
+ {
+ OpBuilder::InsertionGuard g(rewriter);
+ rewriter.setInsertionPoint(term);
+ rewriter.replaceOpWithNewOp<ConditionOp>(term, term.condition(),
+ newTermArgs);
+ }
+
+ // Replacement loop: same init operands, only the kept result types.
+ auto newWhile =
+ rewriter.create<WhileOp>(op.getLoc(), newResultTypes, op.inits());
+
+ // Fresh 'after' block with one argument per kept result.
+ Block &newAfterBlock = *rewriter.createBlock(
+ &newWhile.after(), /*insertPt*/ {}, newResultTypes);
+
+ // Map each old result / 'after' argument position to its counterpart in
+ // the new loop. Dropped positions stay null; they have no uses (checked
+ // above), so a null entry never stands in for a live value.
+ SmallVector<Value> newResults(op.getNumResults());
+ SmallVector<Value> newAfterBlockArgs(op.getNumResults());
+ for (auto it : llvm::enumerate(newResultsIndices)) {
+ newResults[it.value()] = newWhile.getResult(it.index());
+ newAfterBlockArgs[it.value()] = newAfterBlock.getArgument(it.index());
+ }
+
+ // Move the 'before' region (which ends in the rebuilt scf.condition) into
+ // the new loop unchanged.
+ rewriter.inlineRegionBefore(op.before(), newWhile.before(),
+ newWhile.before().begin());
+
+ // Splice the old 'after' body into the new block, substituting the new
+ // block arguments for the kept old ones.
+ Block &afterBlock = op.after().front();
+ rewriter.mergeBlocks(&afterBlock, &newAfterBlock, newAfterBlockArgs);
+
+ // Dropped results are unused, so mapping them to null values is safe.
+ rewriter.replaceOp(op, newResults);
+ return success();
+ }
+};
} // namespace
+/// Registers the scf.while canonicalization patterns (WhileConditionTruth
+/// and WhileUnusedResult) in `results`.
void WhileOp::getCanonicalizationPatterns(OwningRewritePatternList &results,
MLIRContext *context) {
- results.insert<WhileConditionTruth>(context);
+ results.insert<WhileConditionTruth, WhileUnusedResult>(context);
}
//===----------------------------------------------------------------------===//
diff --git a/mlir/test/Dialect/SCF/canonicalize.mlir b/mlir/test/Dialect/SCF/canonicalize.mlir
index c8471b8adbc60..6f1815713819c 100644
--- a/mlir/test/Dialect/SCF/canonicalize.mlir
+++ b/mlir/test/Dialect/SCF/canonicalize.mlir
@@ -782,7 +782,7 @@ func @replace_if_with_cond3(%arg0 : i1, %arg2: i64) -> (i32, i64) {
// -----
// CHECK-LABEL: @while_cond_true
-func @while_cond_true() {
+func @while_cond_true() -> i1 {
%0 = scf.while () : () -> i1 {
%condition = "test.condition"() : () -> i1
scf.condition(%condition) %condition : i1
@@ -791,7 +791,7 @@ func @while_cond_true() {
"test.use"(%arg0) : (i1) -> ()
scf.yield
}
- return
+ return %0 : i1
}
// CHECK-NEXT: %[[true:.+]] = arith.constant true
// CHECK-NEXT: %{{.+}} = scf.while : () -> i1 {
@@ -805,6 +805,34 @@ func @while_cond_true() {
// -----
+// The i64 result (%0#1) is unused and its 'after' argument (%arg1) is never
+// read, so canonicalization is expected to drop that position from the loop.
+// The i64 producer in the 'before' region stays; only the scf.condition
+// operand list shrinks.
+// CHECK-LABEL: @while_unused_result
+func @while_unused_result() -> i32 {
+ %0:2 = scf.while () : () -> (i32, i64) {
+ %condition = "test.condition"() : () -> i1
+ %v1 = "test.get_some_value"() : () -> i32
+ %v2 = "test.get_some_value"() : () -> i64
+ scf.condition(%condition) %v1, %v2 : i32, i64
+ } do {
+ ^bb0(%arg0: i32, %arg1: i64):
+ "test.use"(%arg0) : (i32) -> ()
+ scf.yield
+ }
+ return %0#0 : i32
+}
+// CHECK-NEXT: %[[res:.*]] = scf.while : () -> i32 {
+// CHECK-NEXT: %[[cmp:.*]] = "test.condition"() : () -> i1
+// CHECK-NEXT: %[[val:.*]] = "test.get_some_value"() : () -> i32
+// CHECK-NEXT: %{{.*}} = "test.get_some_value"() : () -> i64
+// CHECK-NEXT: scf.condition(%[[cmp]]) %[[val]] : i32
+// CHECK-NEXT: } do {
+// CHECK-NEXT: ^bb0(%[[arg:.*]]: i32): // no predecessors
+// CHECK-NEXT: "test.use"(%[[arg]]) : (i32) -> ()
+// CHECK-NEXT: scf.yield
+// CHECK-NEXT: }
+// CHECK-NEXT: return %[[res]] : i32
+
+// -----
+
// CHECK-LABEL: @combineIfs
func @combineIfs(%arg0 : i1, %arg2: i64) -> (i32, i32) {
%res = scf.if %arg0 -> i32 {
More information about the Mlir-commits mailing list