[Mlir-commits] [mlir] f4a2dbf - [MLIR][SCF] Combine adjacent scf.if with same condition
William S. Moses
llvmlistbot at llvm.org
Tue May 4 21:40:15 PDT 2021
Author: William S. Moses
Date: 2021-05-05T00:39:58-04:00
New Revision: f4a2dbfe29031f02c02d6045159f22785dd611cf
URL: https://github.com/llvm/llvm-project/commit/f4a2dbfe29031f02c02d6045159f22785dd611cf
DIFF: https://github.com/llvm/llvm-project/commit/f4a2dbfe29031f02c02d6045159f22785dd611cf.diff
LOG: [MLIR][SCF] Combine adjacent scf.if with same condition
Differential Revision: https://reviews.llvm.org/D101798
Added:
Modified:
mlir/include/mlir/Dialect/SCF/SCFOps.td
mlir/lib/Dialect/SCF/SCF.cpp
mlir/test/Dialect/SCF/canonicalize.mlir
Removed:
################################################################################
diff --git a/mlir/include/mlir/Dialect/SCF/SCFOps.td b/mlir/include/mlir/Dialect/SCF/SCFOps.td
index c3c64e04ae08..fcaba0e6f2b2 100644
--- a/mlir/include/mlir/Dialect/SCF/SCFOps.td
+++ b/mlir/include/mlir/Dialect/SCF/SCFOps.td
@@ -311,6 +311,10 @@ def IfOp : SCF_Op<"if",
return results().empty() ? OpBuilder::atBlockTerminator(body, listener)
: OpBuilder::atBlockEnd(body, listener);
}
+ Block* thenBlock();
+ YieldOp thenYield();
+ Block* elseBlock();
+ YieldOp elseYield();
}];
let hasCanonicalizer = 1;
diff --git a/mlir/lib/Dialect/SCF/SCF.cpp b/mlir/lib/Dialect/SCF/SCF.cpp
index c28e438fc819..91f1e7a3e7c0 100644
--- a/mlir/lib/Dialect/SCF/SCF.cpp
+++ b/mlir/lib/Dialect/SCF/SCF.cpp
@@ -1266,15 +1266,125 @@ struct ReplaceIfYieldWithConditionOrValue : public OpRewritePattern<IfOp> {
}
};
+/// Merge any consecutive scf.if's with the same condition.
+///
+/// The two ops must be adjacent in the same block and use the same SSA
+/// condition value. Merging is refused when any result of the first scf.if is
+/// used inside the second one: after merging, that use would sit inside the
+/// combined op's own regions, where the combined op's results are not
+/// available.
+///
+///    scf.if %cond {
+///       firstCodeTrue();...
+///    } else {
+///       firstCodeFalse();...
+///    }
+///    %res = scf.if %cond {
+///       secondCodeTrue();...
+///    } else {
+///       secondCodeFalse();...
+///    }
+///
+///  becomes
+///    %res = scf.if %cond {
+///       firstCodeTrue();...
+///       secondCodeTrue();...
+///    } else {
+///       firstCodeFalse();...
+///       secondCodeFalse();...
+///    }
+///
+/// The combined op yields the results of the first scf.if followed by the
+/// results of the second. All block/region movement goes through the
+/// PatternRewriter so that rewrite listeners observe every IR change.
+struct CombineIfs : public OpRewritePattern<IfOp> {
+  using OpRewritePattern<IfOp>::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(IfOp nextIf,
+                                PatternRewriter &rewriter) const override {
+    // getPrevNode() is null for the first op; dyn_cast must not see null.
+    Block *parent = nextIf->getBlock();
+    if (nextIf == &parent->front())
+      return failure();
+
+    auto prevIf = dyn_cast<IfOp>(nextIf->getPrevNode());
+    if (!prevIf)
+      return failure();
+
+    if (nextIf.condition() != prevIf.condition())
+      return failure();
+
+    // Don't permit merging if a result of the first if is used
+    // within the second.
+    if (llvm::any_of(prevIf->getUsers(),
+                     [&](Operation *user) { return nextIf->isAncestor(user); }))
+      return failure();
+
+    SmallVector<Type> mergedTypes(prevIf.getResultTypes());
+    llvm::append_range(mergedTypes, nextIf.getResultTypes());
+
+    // Create the combined op and drop its freshly built then block; the
+    // blocks of prevIf are moved in instead.
+    IfOp combinedIf = rewriter.create<IfOp>(
+        nextIf.getLoc(), mergedTypes, nextIf.condition(), /*hasElse=*/false);
+    rewriter.eraseBlock(&combinedIf.thenRegion().back());
+
+    // Capture the terminators before their blocks are moved.
+    YieldOp thenYield = prevIf.thenYield();
+    YieldOp thenYield2 = nextIf.thenYield();
+
+    rewriter.inlineRegionBefore(prevIf.thenRegion(), combinedIf.thenRegion(),
+                                combinedIf.thenRegion().begin());
+
+    // Append nextIf's then body after prevIf's, then replace both yields with
+    // one yield of the concatenated operands.
+    rewriter.mergeBlocks(nextIf.thenBlock(), combinedIf.thenBlock());
+    rewriter.setInsertionPointToEnd(combinedIf.thenBlock());
+
+    SmallVector<Value> mergedYields(thenYield.getOperands());
+    llvm::append_range(mergedYields, thenYield2.getOperands());
+    rewriter.create<YieldOp>(thenYield2.getLoc(), mergedYields);
+    rewriter.eraseOp(thenYield);
+    rewriter.eraseOp(thenYield2);
+
+    rewriter.inlineRegionBefore(prevIf.elseRegion(), combinedIf.elseRegion(),
+                                combinedIf.elseRegion().begin());
+
+    if (!nextIf.elseRegion().empty()) {
+      if (combinedIf.elseRegion().empty()) {
+        // prevIf had no else region: nextIf's else region is used as-is.
+        rewriter.inlineRegionBefore(nextIf.elseRegion(),
+                                    combinedIf.elseRegion(),
+                                    combinedIf.elseRegion().begin());
+      } else {
+        YieldOp elseYield = combinedIf.elseYield();
+        YieldOp elseYield2 = nextIf.elseYield();
+        rewriter.mergeBlocks(nextIf.elseBlock(), combinedIf.elseBlock());
+
+        rewriter.setInsertionPointToEnd(combinedIf.elseBlock());
+
+        SmallVector<Value> mergedElseYields(elseYield.getOperands());
+        llvm::append_range(mergedElseYields, elseYield2.getOperands());
+
+        rewriter.create<YieldOp>(elseYield2.getLoc(), mergedElseYields);
+        rewriter.eraseOp(elseYield);
+        rewriter.eraseOp(elseYield2);
+      }
+    }
+
+    // The leading results belong to prevIf, the remaining ones to nextIf.
+    unsigned numPrevResults = prevIf.getNumResults();
+    rewriter.replaceOp(prevIf,
+                       combinedIf.getResults().take_front(numPrevResults));
+    rewriter.replaceOp(nextIf,
+                       combinedIf.getResults().drop_front(numPrevResults));
+    return success();
+  }
+};
+
} // namespace
void IfOp::getCanonicalizationPatterns(RewritePatternSet &results,
MLIRContext *context) {
- results
- .add<RemoveUnusedResults, RemoveStaticCondition, ConvertTrivialIfToSelect,
- ConditionPropagation, ReplaceIfYieldWithConditionOrValue>(context);
+ // Same scf.if canonicalization patterns as before, with CombineIfs
+ // (merging of adjacent scf.if ops on the same condition) appended.
+ results.add<RemoveUnusedResults, RemoveStaticCondition,
+ ConvertTrivialIfToSelect, ConditionPropagation,
+ ReplaceIfYieldWithConditionOrValue, CombineIfs>(context);
}
+/// Returns the (last) block of the "then" region.
+Block *IfOp::thenBlock() { return &thenRegion().back(); }
+/// Returns the scf.yield that terminates the "then" block.
+YieldOp IfOp::thenYield() { return cast<YieldOp>(&thenBlock()->back()); }
+/// Returns the (last) block of the "else" region, or nullptr when the op has
+/// no else region (calling back() on an empty region is undefined).
+Block *IfOp::elseBlock() {
+  Region &region = elseRegion();
+  if (region.empty())
+    return nullptr;
+  return &region.back();
+}
+/// Returns the scf.yield that terminates the "else" block. The op must have a
+/// non-empty else region.
+YieldOp IfOp::elseYield() {
+  Block *block = elseBlock();
+  assert(block && "scf.if has no else region");
+  return cast<YieldOp>(&block->back());
+}
+
//===----------------------------------------------------------------------===//
// ParallelOp
//===----------------------------------------------------------------------===//
diff --git a/mlir/test/Dialect/SCF/canonicalize.mlir b/mlir/test/Dialect/SCF/canonicalize.mlir
index 3ba8e8023155..8f12c90b7729 100644
--- a/mlir/test/Dialect/SCF/canonicalize.mlir
+++ b/mlir/test/Dialect/SCF/canonicalize.mlir
@@ -747,3 +747,104 @@ func @while_cond_true() {
// CHECK-NEXT: "test.use"(%[[true]]) : (i1) -> ()
// CHECK-NEXT: scf.yield
// CHECK-NEXT: }
+
+// -----
+
+// Both scf.if ops yield a value on each branch: they merge into a single op
+// yielding (i32, i32), with the first op's result first.
+// CHECK-LABEL: @combineIfs
+func @combineIfs(%arg0 : i1, %arg2: i64) -> (i32, i32) {
+ %res = scf.if %arg0 -> i32 {
+ %v = "test.firstCodeTrue"() : () -> i32
+ scf.yield %v : i32
+ } else {
+ %v2 = "test.firstCodeFalse"() : () -> i32
+ scf.yield %v2 : i32
+ }
+ %res2 = scf.if %arg0 -> i32 {
+ %v = "test.secondCodeTrue"() : () -> i32
+ scf.yield %v : i32
+ } else {
+ %v2 = "test.secondCodeFalse"() : () -> i32
+ scf.yield %v2 : i32
+ }
+ return %res, %res2 : i32, i32
+}
+// CHECK-NEXT: %[[res:.+]]:2 = scf.if %arg0 -> (i32, i32) {
+// CHECK-NEXT: %[[tval0:.+]] = "test.firstCodeTrue"() : () -> i32
+// CHECK-NEXT: %[[tval:.+]] = "test.secondCodeTrue"() : () -> i32
+// CHECK-NEXT: scf.yield %[[tval0]], %[[tval]] : i32, i32
+// CHECK-NEXT: } else {
+// CHECK-NEXT: %[[fval0:.+]] = "test.firstCodeFalse"() : () -> i32
+// CHECK-NEXT: %[[fval:.+]] = "test.secondCodeFalse"() : () -> i32
+// CHECK-NEXT: scf.yield %[[fval0]], %[[fval]] : i32, i32
+// CHECK-NEXT: }
+// CHECK-NEXT: return %[[res]]#0, %[[res]]#1 : i32, i32
+
+
+// The first scf.if has no else region and no results: its then body is placed
+// before the second op's then body, and the else region comes only from the
+// second op.
+// CHECK-LABEL: @combineIfs2
+func @combineIfs2(%arg0 : i1, %arg2: i64) -> i32 {
+ scf.if %arg0 {
+ "test.firstCodeTrue"() : () -> ()
+ scf.yield
+ }
+ %res = scf.if %arg0 -> i32 {
+ %v = "test.secondCodeTrue"() : () -> i32
+ scf.yield %v : i32
+ } else {
+ %v2 = "test.secondCodeFalse"() : () -> i32
+ scf.yield %v2 : i32
+ }
+ return %res : i32
+}
+// CHECK-NEXT: %[[res:.+]] = scf.if %arg0 -> (i32) {
+// CHECK-NEXT: "test.firstCodeTrue"() : () -> ()
+// CHECK-NEXT: %[[tval:.+]] = "test.secondCodeTrue"() : () -> i32
+// CHECK-NEXT: scf.yield %[[tval]] : i32
+// CHECK-NEXT: } else {
+// CHECK-NEXT: %[[fval:.+]] = "test.secondCodeFalse"() : () -> i32
+// CHECK-NEXT: scf.yield %[[fval]] : i32
+// CHECK-NEXT: }
+// CHECK-NEXT: return %[[res]] : i32
+
+
+// The second scf.if has no else region and no results: its then body is
+// appended to the first op's then body, and the first op's else region is
+// kept unchanged.
+// CHECK-LABEL: @combineIfs3
+func @combineIfs3(%arg0 : i1, %arg2: i64) -> i32 {
+ %res = scf.if %arg0 -> i32 {
+ %v = "test.firstCodeTrue"() : () -> i32
+ scf.yield %v : i32
+ } else {
+ %v2 = "test.firstCodeFalse"() : () -> i32
+ scf.yield %v2 : i32
+ }
+ scf.if %arg0 {
+ "test.secondCodeTrue"() : () -> ()
+ scf.yield
+ }
+ return %res : i32
+}
+// CHECK-NEXT: %[[res:.+]] = scf.if %arg0 -> (i32) {
+// CHECK-NEXT: %[[tval:.+]] = "test.firstCodeTrue"() : () -> i32
+// CHECK-NEXT: "test.secondCodeTrue"() : () -> ()
+// CHECK-NEXT: scf.yield %[[tval]] : i32
+// CHECK-NEXT: } else {
+// CHECK-NEXT: %[[fval:.+]] = "test.firstCodeFalse"() : () -> i32
+// CHECK-NEXT: scf.yield %[[fval]] : i32
+// CHECK-NEXT: }
+// CHECK-NEXT: return %[[res]] : i32
+
+// Neither scf.if has results or an else region: the merged op contains both
+// then bodies in order and still has no else region. The trailing return
+// check ensures the closing brace is not the start of an "} else {".
+// CHECK-LABEL: @combineIfs4
+func @combineIfs4(%arg0 : i1, %arg2: i64) {
+ scf.if %arg0 {
+ "test.firstCodeTrue"() : () -> ()
+ scf.yield
+ }
+ scf.if %arg0 {
+ "test.secondCodeTrue"() : () -> ()
+ scf.yield
+ }
+ return
+}
+
+// CHECK-NEXT: scf.if %arg0 {
+// CHECK-NEXT: "test.firstCodeTrue"() : () -> ()
+// CHECK-NEXT: "test.secondCodeTrue"() : () -> ()
+// CHECK-NEXT: }
+// CHECK-NEXT: return
More information about the Mlir-commits
mailing list