[Mlir-commits] [mlir] dbe96c8 - [MLIR][SCF] Combine nested ifs with yields
William S. Moses
llvmlistbot at llvm.org
Fri Mar 18 10:00:33 PDT 2022
Author: William S. Moses
Date: 2022-03-18T13:00:28-04:00
New Revision: dbe96c8da02141541d7fbde0775494235b98089a
URL: https://github.com/llvm/llvm-project/commit/dbe96c8da02141541d7fbde0775494235b98089a
DIFF: https://github.com/llvm/llvm-project/commit/dbe96c8da02141541d7fbde0775494235b98089a.diff
LOG: [MLIR][SCF] Combine nested ifs with yields
This patch extends the existing canonicalization that combines nested
scf.if ops so that it also handles scf.if ops which yield values.
Reviewed By: ftynse
Differential Revision: https://reviews.llvm.org/D121923
Added:
Modified:
mlir/lib/Dialect/SCF/SCF.cpp
mlir/test/Dialect/SCF/canonicalize.mlir
Removed:
################################################################################
diff --git a/mlir/lib/Dialect/SCF/SCF.cpp b/mlir/lib/Dialect/SCF/SCF.cpp
index d888a341f0781..52004bea515b4 100644
--- a/mlir/lib/Dialect/SCF/SCF.cpp
+++ b/mlir/lib/Dialect/SCF/SCF.cpp
@@ -1746,27 +1746,94 @@ struct CombineNestedIfs : public OpRewritePattern<IfOp> {
LogicalResult matchAndRewrite(IfOp op,
PatternRewriter &rewriter) const override {
- // Both `if` ops must not yield results and have only `then` block.
- if (op->getNumResults() != 0 || op.elseBlock())
- return failure();
-
auto nestedOps = op.thenBlock()->without_terminator();
// Nested `if` must be the only op in block.
if (!llvm::hasSingleElement(nestedOps))
return failure();
+ // If there is an else block, it can only yield.
+ if (op.elseBlock() && !llvm::hasSingleElement(*op.elseBlock()))
+ return failure();
+
auto nestedIf = dyn_cast<IfOp>(*nestedOps.begin());
- if (!nestedIf || nestedIf->getNumResults() != 0 || nestedIf.elseBlock())
+ if (!nestedIf)
+ return failure();
+
+ // The nested else block, if present, must also only yield.
+ if (nestedIf.elseBlock() && !llvm::hasSingleElement(*nestedIf.elseBlock()))
return failure();
+ SmallVector<Value> thenYield(op.thenYield().getOperands());
+ SmallVector<Value> elseYield;
+ if (op.elseBlock())
+ llvm::append_range(elseYield, op.elseYield().getOperands());
+
+ // Indices of outer results whose `then` value is defined outside of `op`.
+ // Merging the conditions would otherwise change what these results hold
+ // when the outer condition is true but the inner one is false, so their
+ // uses are redirected to a select on the outer condition alone.
+ SmallVector<unsigned> elseYieldsToUpgradeToSelect;
+
+ // If the outer scf.if yields a value produced by the inner scf.if,
+ // only permit combining if the value yielded when the condition
+ // is false in the outer scf.if is the same value yielded when the
+ // inner scf.if condition is false.
+ // Note that the array access to elseYield will not go out of bounds
+ // since it must have the same length as thenYield, since they both
+ // come from the same scf.if.
+ for (const auto &tup : llvm::enumerate(thenYield)) {
+ if (tup.value().getDefiningOp() == nestedIf) {
+ auto nestedIdx = tup.value().cast<OpResult>().getResultNumber();
+ if (nestedIf.elseYield().getOperand(nestedIdx) !=
+ elseYield[tup.index()]) {
+ return failure();
+ }
+ // If the correctness test passes, we will yield
+ // corresponding value from the inner scf.if
+ thenYield[tup.index()] = nestedIf.thenYield().getOperand(nestedIdx);
+ continue;
+ }
+
+ // Otherwise the `then` value must be defined outside of `op` so that it
+ // can feed a select placed before the combined scf.if. The `else` block
+ // only yields, so its values are always defined outside of `op`.
+ if (tup.value().getParentRegion() == &op.getThenRegion())
+ return failure();
+ elseYieldsToUpgradeToSelect.push_back(tup.index());
+ }
+
Location loc = op.getLoc();
Value newCondition = rewriter.create<arith::AndIOp>(
loc, op.getCondition(), nestedIf.getCondition());
- auto newIf = rewriter.create<IfOp>(loc, newCondition);
+ auto newIf = rewriter.create<IfOp>(loc, op.getResultTypes(), newCondition);
+
+ // Results of `op` that the combined scf.if cannot reproduce are
+ // recomputed from the outer condition before the combined scf.if.
+ SmallVector<Value> results;
+ llvm::append_range(results, newIf.getResults());
+ rewriter.setInsertionPoint(newIf);
+ for (unsigned idx : elseYieldsToUpgradeToSelect) {
+ // Both branches yield the same value, so no select is needed.
+ if (thenYield[idx] == elseYield[idx])
+ results[idx] = thenYield[idx];
+ else
+ results[idx] = rewriter.create<arith::SelectOp>(
+ loc, op.getCondition(), thenYield[idx], elseYield[idx]);
+ }
+
Block *newIfBlock = newIf.thenBlock();
- rewriter.eraseOp(newIfBlock->getTerminator());
+ if (newIfBlock)
+ rewriter.eraseOp(newIfBlock->getTerminator());
+ else
+ newIfBlock = rewriter.createBlock(&newIf.getThenRegion());
rewriter.mergeBlocks(nestedIf.thenBlock(), newIfBlock);
- rewriter.eraseOp(op);
+ rewriter.setInsertionPointToEnd(newIf.thenBlock());
+ rewriter.replaceOpWithNewOp<YieldOp>(newIf.thenYield(), thenYield);
+ if (!elseYield.empty()) {
+ rewriter.createBlock(&newIf.getElseRegion());
+ rewriter.setInsertionPointToEnd(newIf.elseBlock());
+ rewriter.create<YieldOp>(loc, elseYield);
+ }
+ rewriter.replaceOp(op, results);
return success();
}
};
diff --git a/mlir/test/Dialect/SCF/canonicalize.mlir b/mlir/test/Dialect/SCF/canonicalize.mlir
index f06e9833d1211..b69702331e741 100644
--- a/mlir/test/Dialect/SCF/canonicalize.mlir
+++ b/mlir/test/Dialect/SCF/canonicalize.mlir
@@ -491,6 +491,113 @@ func @merge_nested_if(%arg0: i1, %arg1: i1) {
// -----
+// CHECK-LABEL: @merge_yielding_nested_if
+// CHECK-SAME: (%[[ARG0:.*]]: i1, %[[ARG1:.*]]: i1)
+func @merge_yielding_nested_if(%arg0: i1, %arg1: i1) -> (i32, f32, i32, i8) {
+// CHECK: %[[PRE0:.*]] = "test.op"() : () -> i32
+// CHECK: %[[PRE1:.*]] = "test.op1"() : () -> f32
+// CHECK: %[[PRE2:.*]] = "test.op2"() : () -> i32
+// CHECK: %[[PRE3:.*]] = "test.op3"() : () -> i8
+// CHECK: %[[COND:.*]] = arith.andi %[[ARG0]], %[[ARG1]]
+// CHECK: %[[RES:.*]]:2 = scf.if %[[COND]] -> (f32, i32)
+// CHECK: %[[IN0:.*]] = "test.inop"() : () -> i32
+// CHECK: %[[IN1:.*]] = "test.inop1"() : () -> f32
+// CHECK: scf.yield %[[IN1]], %[[IN0]] : f32, i32
+// CHECK: } else {
+// CHECK: scf.yield %[[PRE1]], %[[PRE2]] : f32, i32
+// CHECK: }
+// CHECK: return %[[PRE0]], %[[RES]]#0, %[[RES]]#1, %[[PRE3]] : i32, f32, i32, i8
+ %0 = "test.op"() : () -> (i32)
+ %1 = "test.op1"() : () -> (f32)
+ %2 = "test.op2"() : () -> (i32)
+ %3 = "test.op3"() : () -> (i8)
+ %r:4 = scf.if %arg0 -> (i32, f32, i32, i8) {
+ %a:2 = scf.if %arg1 -> (i32, f32) {
+ %i = "test.inop"() : () -> (i32)
+ %i1 = "test.inop1"() : () -> (f32)
+ scf.yield %i, %i1 : i32, f32
+ } else {
+ scf.yield %2, %1 : i32, f32
+ }
+ scf.yield %0, %a#1, %a#0, %3 : i32, f32, i32, i8
+ } else {
+ scf.yield %0, %1, %2, %3 : i32, f32, i32, i8
+ }
+ return %r#0, %r#1, %r#2, %r#3 : i32, f32, i32, i8
+}
+
+// CHECK-LABEL: @merge_yielding_nested_if_nv1
+// CHECK-SAME: (%[[ARG0:.*]]: i1, %[[ARG1:.*]]: i1)
+func @merge_yielding_nested_if_nv1(%arg0: i1, %arg1: i1) {
+// CHECK: %[[PRE0:.*]] = "test.op"() : () -> i32
+// CHECK: %[[PRE1:.*]] = "test.op1"() : () -> f32
+// CHECK: %[[COND:.*]] = arith.andi %[[ARG0]], %[[ARG1]]
+// CHECK: scf.if %[[COND]]
+// CHECK: %[[IN0:.*]] = "test.inop"() : () -> i32
+// CHECK: %[[IN1:.*]] = "test.inop1"() : () -> f32
+// CHECK: }
+ %0 = "test.op"() : () -> (i32)
+ %1 = "test.op1"() : () -> (f32)
+ scf.if %arg0 {
+ %a:2 = scf.if %arg1 -> (i32, f32) {
+ %i = "test.inop"() : () -> (i32)
+ %i1 = "test.inop1"() : () -> (f32)
+ scf.yield %i, %i1 : i32, f32
+ } else {
+ scf.yield %0, %1 : i32, f32
+ }
+ }
+ return
+}
+
+// CHECK-LABEL: @merge_yielding_nested_if_nv2
+// CHECK-SAME: (%[[ARG0:.*]]: i1, %[[ARG1:.*]]: i1)
+func @merge_yielding_nested_if_nv2(%arg0: i1, %arg1: i1) -> i32 {
+// CHECK: %[[PRE0:.*]] = "test.op"() : () -> i32
+// CHECK: %[[PRE1:.*]] = "test.op1"() : () -> i32
+// CHECK: %[[COND:.*]] = arith.andi %[[ARG0]], %[[ARG1]]
+// CHECK: %[[RES:.*]] = arith.select %[[ARG0]], %[[PRE0]], %[[PRE1]]
+// CHECK: scf.if %[[COND]]
+// CHECK: "test.run"() : () -> ()
+// CHECK: }
+// CHECK: return %[[RES]]
+ %0 = "test.op"() : () -> (i32)
+ %1 = "test.op1"() : () -> (i32)
+ %r = scf.if %arg0 -> i32 {
+ scf.if %arg1 {
+ "test.run"() : () -> ()
+ }
+ scf.yield %0 : i32
+ } else {
+ scf.yield %1 : i32
+ }
+ return %r : i32
+}
+
+// CHECK-LABEL: @merge_fail_yielding_nested_if
+// CHECK-SAME: (%[[ARG0:.*]]: i1, %[[ARG1:.*]]: i1)
+func @merge_fail_yielding_nested_if(%arg0: i1, %arg1: i1) -> (i32, f32, i32, i8) {
+// CHECK-NOT: andi
+ %0 = "test.op"() : () -> (i32)
+ %1 = "test.op1"() : () -> (f32)
+ %2 = "test.op2"() : () -> (i32)
+ %3 = "test.op3"() : () -> (i8)
+ %r:4 = scf.if %arg0 -> (i32, f32, i32, i8) {
+ %a:2 = scf.if %arg1 -> (i32, f32) {
+ %i = "test.inop"() : () -> (i32)
+ %i1 = "test.inop1"() : () -> (f32)
+ scf.yield %i, %i1 : i32, f32
+ } else {
+ scf.yield %0, %1 : i32, f32
+ }
+ scf.yield %0, %a#1, %a#0, %3 : i32, f32, i32, i8
+ } else {
+ scf.yield %0, %1, %2, %3 : i32, f32, i32, i8
+ }
+ return %r#0, %r#1, %r#2, %r#3 : i32, f32, i32, i8
+}
+// -----
+
// CHECK-LABEL: func @if_condition_swap
// CHECK-NEXT: %{{.*}} = scf.if %arg0 -> (index) {
// CHECK-NEXT: %[[i1:.+]] = "test.origFalse"() : () -> index
More information about the Mlir-commits
mailing list