[Mlir-commits] [mlir] dbe96c8 - [MLIR][SCF] Combine nested ifs with yields

William S. Moses llvmlistbot at llvm.org
Fri Mar 18 10:00:33 PDT 2022


Author: William S. Moses
Date: 2022-03-18T13:00:28-04:00
New Revision: dbe96c8da02141541d7fbde0775494235b98089a

URL: https://github.com/llvm/llvm-project/commit/dbe96c8da02141541d7fbde0775494235b98089a
DIFF: https://github.com/llvm/llvm-project/commit/dbe96c8da02141541d7fbde0775494235b98089a.diff

LOG: [MLIR][SCF] Combine nested ifs with yields

This patch extends the existing canonicalization that
combines nested ifs so that it also handles ifs which
yield values.

Reviewed By: ftynse

Differential Revision: https://reviews.llvm.org/D121923

Added: 
    

Modified: 
    mlir/lib/Dialect/SCF/SCF.cpp
    mlir/test/Dialect/SCF/canonicalize.mlir

Removed: 
    


################################################################################
diff  --git a/mlir/lib/Dialect/SCF/SCF.cpp b/mlir/lib/Dialect/SCF/SCF.cpp
index d888a341f0781..52004bea515b4 100644
--- a/mlir/lib/Dialect/SCF/SCF.cpp
+++ b/mlir/lib/Dialect/SCF/SCF.cpp
@@ -1746,27 +1746,90 @@ struct CombineNestedIfs : public OpRewritePattern<IfOp> {
 
   LogicalResult matchAndRewrite(IfOp op,
                                 PatternRewriter &rewriter) const override {
-    // Both `if` ops must not yield results and have only `then` block.
-    if (op->getNumResults() != 0 || op.elseBlock())
-      return failure();
-
     auto nestedOps = op.thenBlock()->without_terminator();
     // Nested `if` must be the only op in block.
     if (!llvm::hasSingleElement(nestedOps))
       return failure();
 
+    // If there is an else block, it may only contain its terminating yield.
+    if (op.elseBlock() && !llvm::hasSingleElement(*op.elseBlock()))
+      return failure();
+
     auto nestedIf = dyn_cast<IfOp>(*nestedOps.begin());
-    if (!nestedIf || nestedIf->getNumResults() != 0 || nestedIf.elseBlock())
+    if (!nestedIf)
+      return failure();
+
+    if (nestedIf.elseBlock() && !llvm::hasSingleElement(*nestedIf.elseBlock()))
       return failure();
 
+    SmallVector<Value> thenYield(op.thenYield().getOperands());
+    SmallVector<Value> elseYield;
+    if (op.elseBlock())
+      llvm::append_range(elseYield, op.elseYield().getOperands());
+
+    // Indices of outer results that must be chosen by the outer condition
+    // alone; they are materialized as arith.select ops after the new scf.if.
+    SmallVector<unsigned> resultsToSelect;
+
+    // Indexing elseYield with a thenYield index cannot go out of bounds: an
+    // scf.if that yields values must have an else region, and both of its
+    // yields carry exactly one operand per result.
+    for (auto tup : llvm::enumerate(thenYield)) {
+      Value thenValue = tup.value();
+      unsigned idx = tup.index();
+
+      // If the outer scf.if yields a value produced by the inner scf.if,
+      // only permit combining if the value yielded when the condition
+      // is false in the outer scf.if is the same value yielded when the
+      // inner scf.if condition is false.
+      if (thenValue.getDefiningOp() == nestedIf) {
+        auto nestedIdx = thenValue.cast<OpResult>().getResultNumber();
+        if (nestedIf.elseYield().getOperand(nestedIdx) != elseYield[idx])
+          return failure();
+        // If the correctness test passes, we will yield the
+        // corresponding value from the inner scf.if.
+        thenYield[idx] = nestedIf.thenYield().getOperand(nestedIdx);
+        continue;
+      }
+
+      // Any other value is yielded whenever the outer condition holds, even
+      // when the inner one does not, whereas the combined scf.if would yield
+      // the outer else value in that case. Choose it with a select on the
+      // outer condition instead; that requires the value to be defined above
+      // the outer scf.if (the else value always is: that block only yields).
+      if (thenValue.getParentRegion() == &op.getThenRegion())
+        return failure();
+      if (thenValue != elseYield[idx])
+        resultsToSelect.push_back(idx);
+    }
+
     Location loc = op.getLoc();
     Value newCondition = rewriter.create<arith::AndIOp>(
         loc, op.getCondition(), nestedIf.getCondition());
-    auto newIf = rewriter.create<IfOp>(loc, newCondition);
+    auto newIf = rewriter.create<IfOp>(loc, op.getResultTypes(), newCondition);
+
+    // Values that replace `op`. The rewriter still inserts right before `op`,
+    // so the selects are created just after the new scf.if.
+    SmallVector<Value> results;
+    llvm::append_range(results, newIf.getResults());
+    for (unsigned idx : resultsToSelect)
+      results[idx] = rewriter.create<arith::SelectOp>(
+          loc, op.getCondition(), thenYield[idx], elseYield[idx]);
+
     Block *newIfBlock = newIf.thenBlock();
-    rewriter.eraseOp(newIfBlock->getTerminator());
+    if (newIfBlock)
+      rewriter.eraseOp(newIfBlock->getTerminator());
+    else
+      newIfBlock = rewriter.createBlock(&newIf.getThenRegion());
     rewriter.mergeBlocks(nestedIf.thenBlock(), newIfBlock);
-    rewriter.eraseOp(op);
+    rewriter.setInsertionPointToEnd(newIf.thenBlock());
+    rewriter.replaceOpWithNewOp<YieldOp>(newIf.thenYield(), thenYield);
+    if (!elseYield.empty()) {
+      rewriter.createBlock(&newIf.getElseRegion());
+      rewriter.setInsertionPointToEnd(newIf.elseBlock());
+      rewriter.create<YieldOp>(loc, elseYield);
+    }
+    rewriter.replaceOp(op, results);
     return success();
   }
 };

diff  --git a/mlir/test/Dialect/SCF/canonicalize.mlir b/mlir/test/Dialect/SCF/canonicalize.mlir
index f06e9833d1211..b69702331e741 100644
--- a/mlir/test/Dialect/SCF/canonicalize.mlir
+++ b/mlir/test/Dialect/SCF/canonicalize.mlir
@@ -491,6 +491,117 @@ func @merge_nested_if(%arg0: i1, %arg1: i1) {
 
 // -----
 
+// CHECK-LABEL: @merge_yielding_nested_if
+// CHECK-SAME: (%[[ARG0:.*]]: i1, %[[ARG1:.*]]: i1)
+func @merge_yielding_nested_if(%arg0: i1, %arg1: i1) -> (i32, f32, i32, i8) {
+// CHECK: %[[PRE0:.*]] = "test.op"() : () -> i32
+// CHECK: %[[PRE1:.*]] = "test.op1"() : () -> f32
+// CHECK: %[[PRE2:.*]] = "test.op2"() : () -> i32
+// CHECK: %[[PRE3:.*]] = "test.op3"() : () -> i8
+// CHECK: %[[COND:.*]] = arith.andi %[[ARG0]], %[[ARG1]]
+// CHECK: %[[RES:.*]]:2 = scf.if %[[COND]] -> (f32, i32)
+// CHECK:   %[[IN0:.*]] = "test.inop"() : () -> i32
+// CHECK:   %[[IN1:.*]] = "test.inop1"() : () -> f32
+// CHECK:   scf.yield %[[IN1]], %[[IN0]] : f32, i32
+// CHECK: } else {
+// CHECK:   scf.yield %[[PRE1]], %[[PRE2]] : f32, i32
+// CHECK: }
+// CHECK: return %[[PRE0]], %[[RES]]#0, %[[RES]]#1, %[[PRE3]] : i32, f32, i32, i8
+  %0 = "test.op"() : () -> (i32)
+  %1 = "test.op1"() : () -> (f32)
+  %2 = "test.op2"() : () -> (i32)
+  %3 = "test.op3"() : () -> (i8)
+  %r:4 = scf.if %arg0 -> (i32, f32, i32, i8) {
+    %a:2 = scf.if %arg1 -> (i32, f32) {
+      %i = "test.inop"() : () -> (i32)
+      %i1 = "test.inop1"() : () -> (f32)
+      scf.yield %i, %i1 : i32, f32
+    } else {
+      scf.yield %2, %1 : i32, f32
+    }
+    scf.yield %0, %a#1, %a#0, %3 : i32, f32, i32, i8
+  } else {
+    scf.yield %0, %1, %2, %3 : i32, f32, i32, i8
+  }
+  return %r#0, %r#1, %r#2, %r#3 : i32, f32, i32, i8
+}
+
+// CHECK-LABEL: @merge_yielding_nested_if_nv1
+// CHECK-SAME: (%[[ARG0:.*]]: i1, %[[ARG1:.*]]: i1)
+func @merge_yielding_nested_if_nv1(%arg0: i1, %arg1: i1) {
+// CHECK: %[[PRE0:.*]] = "test.op"() : () -> i32
+// CHECK: %[[PRE1:.*]] = "test.op1"() : () -> f32
+// CHECK: %[[COND:.*]] = arith.andi %[[ARG0]], %[[ARG1]]
+// CHECK: scf.if %[[COND]]
+// CHECK:   %[[IN0:.*]] = "test.inop"() : () -> i32
+// CHECK:   %[[IN1:.*]] = "test.inop1"() : () -> f32
+// CHECK: }
+  %0 = "test.op"() : () -> (i32)
+  %1 = "test.op1"() : () -> (f32)
+  scf.if %arg0 {
+    %a:2 = scf.if %arg1 -> (i32, f32) {
+      %i = "test.inop"() : () -> (i32)
+      %i1 = "test.inop1"() : () -> (f32)
+      scf.yield %i, %i1 : i32, f32
+    } else {
+      scf.yield %0, %1 : i32, f32
+    }
+  }
+  return 
+}
+
+// %0 is yielded whenever %arg0 holds, regardless of %arg1, so after merging
+// the result must be selected on %arg0 alone, not on the combined condition.
+// CHECK-LABEL: @merge_yielding_nested_if_nv2
+// CHECK-SAME: (%[[ARG0:.*]]: i1, %[[ARG1:.*]]: i1)
+func @merge_yielding_nested_if_nv2(%arg0: i1, %arg1: i1) -> i32 {
+// CHECK: %[[PRE0:.*]] = "test.op"() : () -> i32
+// CHECK: %[[PRE1:.*]] = "test.op1"() : () -> i32
+// CHECK: %[[COND:.*]] = arith.andi %[[ARG0]], %[[ARG1]]
+// CHECK: scf.if %[[COND]] 
+// CHECK:   "test.run"() : () -> ()
+// CHECK: }
+// CHECK: %[[RES:.*]] = arith.select %[[ARG0]], %[[PRE0]], %[[PRE1]]
+// CHECK: return %[[RES]]
+  %0 = "test.op"() : () -> (i32)
+  %1 = "test.op1"() : () -> (i32)
+  %r = scf.if %arg0 -> i32 {
+    scf.if %arg1 {
+      "test.run"() : () -> ()
+    }
+    scf.yield %0 : i32
+  } else {
+    scf.yield %1 : i32
+  }
+  return %r : i32
+}
+
+// Not combinable: %a#0 is %0 when %arg1 is false, whereas the outer else
+// yields %2 in that position.
+// CHECK-LABEL: @merge_fail_yielding_nested_if
+// CHECK-SAME: (%[[ARG0:.*]]: i1, %[[ARG1:.*]]: i1)
+func @merge_fail_yielding_nested_if(%arg0: i1, %arg1: i1) -> (i32, f32, i32, i8) {
+// CHECK-NOT: andi
+  %0 = "test.op"() : () -> (i32)
+  %1 = "test.op1"() : () -> (f32)
+  %2 = "test.op2"() : () -> (i32)
+  %3 = "test.op3"() : () -> (i8)
+  %r:4 = scf.if %arg0 -> (i32, f32, i32, i8) {
+    %a:2 = scf.if %arg1 -> (i32, f32) {
+      %i = "test.inop"() : () -> (i32)
+      %i1 = "test.inop1"() : () -> (f32)
+      scf.yield %i, %i1 : i32, f32
+    } else {
+      scf.yield %0, %1 : i32, f32
+    }
+    scf.yield %0, %a#1, %a#0, %3 : i32, f32, i32, i8
+  } else {
+    scf.yield %0, %1, %2, %3 : i32, f32, i32, i8
+  }
+  return %r#0, %r#1, %r#2, %r#3 : i32, f32, i32, i8
+}
+// -----
+
 // CHECK-LABEL:   func @if_condition_swap
 // CHECK-NEXT:     %{{.*}} = scf.if %arg0 -> (index) {
 // CHECK-NEXT:       %[[i1:.+]] = "test.origFalse"() : () -> index


        


More information about the Mlir-commits mailing list