[Mlir-commits] [mlir] 9337788 - [mlir] add scf.if op canonicalization pattern that removes unused results
Tobias Gysi
llvmlistbot at llvm.org
Sun Oct 11 01:44:55 PDT 2020
Author: Tobias Gysi
Date: 2020-10-11T10:40:28+02:00
New Revision: 93377888ae89560ba6d3976e2762d3d4724c4dfd
URL: https://github.com/llvm/llvm-project/commit/93377888ae89560ba6d3976e2762d3d4724c4dfd
DIFF: https://github.com/llvm/llvm-project/commit/93377888ae89560ba6d3976e2762d3d4724c4dfd.diff
LOG: [mlir] add scf.if op canonicalization pattern that removes unused results
The patch adds a canonicalization pattern that removes the unused results of scf.if operations. As a result, CSE may remove computations in the then and else regions of the scf.if operation that only produced the removed results.
Reviewed By: mehdi_amini
Differential Revision: https://reviews.llvm.org/D89029
Added:
Modified:
mlir/include/mlir/Dialect/SCF/SCFOps.td
mlir/lib/Dialect/SCF/SCF.cpp
mlir/test/Dialect/SCF/canonicalize.mlir
Removed:
################################################################################
diff --git a/mlir/include/mlir/Dialect/SCF/SCFOps.td b/mlir/include/mlir/Dialect/SCF/SCFOps.td
index d7ff8b6352bb..476898ab2072 100644
--- a/mlir/include/mlir/Dialect/SCF/SCFOps.td
+++ b/mlir/include/mlir/Dialect/SCF/SCFOps.td
@@ -262,6 +262,8 @@ def IfOp : SCF_Op<"if",
: OpBuilder::atBlockEnd(body, listener);
}
}];
+
+ let hasCanonicalizer = 1;
}
def ParallelOp : SCF_Op<"parallel",
diff --git a/mlir/lib/Dialect/SCF/SCF.cpp b/mlir/lib/Dialect/SCF/SCF.cpp
index e36ffc2e6b81..f25ccc454fbc 100644
--- a/mlir/lib/Dialect/SCF/SCF.cpp
+++ b/mlir/lib/Dialect/SCF/SCF.cpp
@@ -508,6 +508,67 @@ void IfOp::getSuccessorRegions(Optional<unsigned> index,
regions.push_back(RegionSuccessor(condition ? &thenRegion() : elseRegion));
}
+namespace {
+/// Canonicalization that drops scf.if results without uses. The op is rebuilt
+/// with only the live result types. Both region bodies are moved into the
+/// rebuilt op, and their yields are trimmed to the matching operands.
+struct RemoveUnusedResults : public OpRewritePattern<IfOp> {
+  using OpRewritePattern<IfOp>::OpRewritePattern;
+
+  /// Splices every operation of `source` onto the end of `dest`. It then
+  /// edits the trailing scf.yield in place so that it yields only the
+  /// operands at the positions of `usedResults`.
+  void transferBody(Block *source, Block *dest, ArrayRef<OpResult> usedResults,
+                    PatternRewriter &rewriter) const {
+    rewriter.mergeBlocks(source, dest);
+
+    auto terminator = cast<scf::YieldOp>(dest->getTerminator());
+    SmallVector<Value, 4> keptOperands;
+    for (OpResult live : usedResults)
+      keptOperands.push_back(terminator.getOperand(live.getResultNumber()));
+
+    // The existing yield is modified (not recreated), so the rewriter must be
+    // told about the in-place update.
+    rewriter.updateRootInPlace(terminator, [&]() {
+      terminator.getOperation()->setOperands(keptOperands);
+    });
+  }
+
+  LogicalResult matchAndRewrite(IfOp op,
+                                PatternRewriter &rewriter) const override {
+    // Gather the results that still have users.
+    SmallVector<OpResult, 4> usedResults;
+    for (OpResult res : op.getResults())
+      if (!res.use_empty())
+        usedResults.push_back(res);
+
+    // Nothing to do when every result is live (this also covers ops that
+    // have no results at all).
+    if (usedResults.size() == op.getNumResults())
+      return failure();
+
+    SmallVector<Type, 4> newTypes;
+    for (OpResult res : usedResults)
+      newTypes.push_back(res.getType());
+
+    // Build the replacement with empty then/else blocks. Their contents are
+    // moved in by transferBody below.
+    auto noBody = [](OpBuilder &, Location) {};
+    IfOp newOp = rewriter.create<IfOp>(op.getLoc(), newTypes, op.condition(),
+                                       noBody, noBody);
+
+    // Reaching this point means the original op has results, which implies
+    // it also has an else region to move.
+    transferBody(op.getBody(0), newOp.getBody(0), usedResults, rewriter);
+    transferBody(op.getBody(1), newOp.getBody(1), usedResults, rewriter);
+
+    // Map each live old result to its new counterpart. Slots for dead results
+    // are left as null Values; those results have no uses to redirect.
+    SmallVector<Value, 4> replacements(op.getNumResults());
+    for (unsigned idx = 0, e = usedResults.size(); idx != e; ++idx)
+      replacements[usedResults[idx].getResultNumber()] = newOp.getResult(idx);
+    rewriter.replaceOp(op, replacements);
+    return success();
+  }
+};
+} // namespace
+
+/// Registers the scf.if canonicalization patterns. At present this is only
+/// RemoveUnusedResults, which drops results that have no uses.
+void IfOp::getCanonicalizationPatterns(OwningRewritePatternList &results,
+                                       MLIRContext *context) {
+  results.insert<RemoveUnusedResults>(context);
+}
+
//===----------------------------------------------------------------------===//
// ParallelOp
//===----------------------------------------------------------------------===//
diff --git a/mlir/test/Dialect/SCF/canonicalize.mlir b/mlir/test/Dialect/SCF/canonicalize.mlir
index fc98dabc0d2d..a96786076109 100644
--- a/mlir/test/Dialect/SCF/canonicalize.mlir
+++ b/mlir/test/Dialect/SCF/canonicalize.mlir
@@ -53,3 +53,87 @@ func @no_iteration(%A: memref<?x?xi32>) {
// CHECK: scf.yield
// CHECK: }
// CHECK: return
+
+// -----
+
+func @one_unused() -> (index) {
+  %c0 = constant 0 : index
+  %c1 = constant 1 : index
+  %true = constant true
+  %0, %1 = scf.if %true -> (index, index) {
+    scf.yield %c0, %c1 : index, index
+  } else {
+    scf.yield %c0, %c1 : index, index
+  }
+  return %1 : index
+}
+
+// Only the second result is used. The first result is dropped, and the
+// constant that fed only that result becomes dead and disappears.
+// CHECK-LABEL: func @one_unused
+// CHECK-NOT:     constant 0
+// CHECK:         [[C1:%.*]] = constant 1 : index
+// CHECK:         [[TRUE:%.*]] = constant true
+// CHECK:         [[V0:%.*]] = scf.if [[TRUE]] -> (index) {
+// CHECK:           scf.yield [[C1]] : index
+// CHECK:         } else
+// CHECK:           scf.yield [[C1]] : index
+// CHECK:         }
+// CHECK:         return [[V0]] : index
+
+// -----
+
+func @nested_unused() -> (index) {
+  %c0 = constant 0 : index
+  %c1 = constant 1 : index
+  %true = constant true
+  %0, %1 = scf.if %true -> (index, index) {
+    %2, %3 = scf.if %true -> (index, index) {
+      scf.yield %c0, %c1 : index, index
+    } else {
+      scf.yield %c0, %c1 : index, index
+    }
+    scf.yield %2, %3 : index, index
+  } else {
+    scf.yield %c0, %c1 : index, index
+  }
+  return %1 : index
+}
+
+// Trimming the outer scf.if leaves the inner result %2 unused, so the inner
+// scf.if is trimmed as well. Afterwards, the first constant is dead.
+// CHECK-LABEL: func @nested_unused
+// CHECK-NOT:     constant 0
+// CHECK:         [[C1:%.*]] = constant 1 : index
+// CHECK:         [[TRUE:%.*]] = constant true
+// CHECK:         [[V0:%.*]] = scf.if [[TRUE]] -> (index) {
+// CHECK:           [[V1:%.*]] = scf.if [[TRUE]] -> (index) {
+// CHECK:             scf.yield [[C1]] : index
+// CHECK:           } else
+// CHECK:             scf.yield [[C1]] : index
+// CHECK:           }
+// CHECK:           scf.yield [[V1]] : index
+// CHECK:         } else
+// CHECK:           scf.yield [[C1]] : index
+// CHECK:         }
+// CHECK:         return [[V0]] : index
+
+// -----
+
+func @side_effect() {}
+func @all_unused() {
+  %c0 = constant 0 : index
+  %c1 = constant 1 : index
+  %true = constant true
+  %0, %1 = scf.if %true -> (index, index) {
+    call @side_effect() : () -> ()
+    scf.yield %c0, %c1 : index, index
+  } else {
+    call @side_effect() : () -> ()
+    scf.yield %c0, %c1 : index, index
+  }
+  return
+}
+
+// No result is used. The side-effecting calls keep the scf.if alive, but it
+// becomes result-less, and both index constants become dead.
+// CHECK-LABEL: func @all_unused
+// CHECK-NOT:     constant {{[0-9]+}} : index
+// CHECK:         [[TRUE:%.*]] = constant true
+// CHECK:         scf.if [[TRUE]] {
+// CHECK:           call @side_effect() : () -> ()
+// CHECK:         } else
+// CHECK:           call @side_effect() : () -> ()
+// CHECK:         }
+// CHECK:         return
More information about the Mlir-commits mailing list