[Mlir-commits] [mlir] 195de3d - [MLIR][SCF] Fix nested if merging bug

William S. Moses llvmlistbot at llvm.org
Mon Mar 21 08:42:32 PDT 2022


Author: William S. Moses
Date: 2022-03-21T11:42:26-04:00
New Revision: 195de3dd6c86f01956f2d1f87b2b7dd25f8c0aed

URL: https://github.com/llvm/llvm-project/commit/195de3dd6c86f01956f2d1f87b2b7dd25f8c0aed
DIFF: https://github.com/llvm/llvm-project/commit/195de3dd6c86f01956f2d1f87b2b7dd25f8c0aed.diff

LOG: [MLIR][SCF] Fix nested if merging bug

The current nested if merging has a bug. Specifically, consider the following code:

```
    %r = scf.if %arg3 -> (i32) {
      scf.if %arg1 {
        "test.op"() : () -> ()
      }
      scf.yield %arg0 : i32
    } else {
      scf.yield %arg2 : i32
    }
```

When the above gets merged, it will become:
```
    %r = scf.if %arg3 && %arg1 -> (i32) {
      "test.op"() : () -> ()
      scf.yield %arg0 : i32
    } else {
      scf.yield %arg2 : i32
    }
```

However, this means that when %arg3 is true but %arg1 is false, we will incorrectly return
%arg2 instead of %arg0. This change updates the behavior of the pass to only enable nested if
merging where the outer yield contains only values from the inner if, or values defined outside of the if.

In the case of the latter, they can be turned into a select on only the outer if condition, thus
maintaining correctness.

Reviewed By: ftynse

Differential Revision: https://reviews.llvm.org/D122108

Added: 
    

Modified: 
    mlir/lib/Dialect/SCF/SCF.cpp
    mlir/test/Dialect/SCF/canonicalize.mlir

Removed: 
    


################################################################################
diff  --git a/mlir/lib/Dialect/SCF/SCF.cpp b/mlir/lib/Dialect/SCF/SCF.cpp
index c36e4ccb82fb0..114d93bf445bf 100644
--- a/mlir/lib/Dialect/SCF/SCF.cpp
+++ b/mlir/lib/Dialect/SCF/SCF.cpp
@@ -1768,6 +1768,10 @@ struct CombineNestedIfs : public OpRewritePattern<IfOp> {
     if (op.elseBlock())
       llvm::append_range(elseYield, op.elseYield().getOperands());
 
+    // A list of indices for which we should upgrade the value yielded
+    // in the else to a select.
+    SmallVector<unsigned> elseYieldsToUpgradeToSelect;
+
     // If the outer scf.if yields a value produced by the inner scf.if,
     // only permit combining if the value yielded when the condition
     // is false in the outer scf.if is the same value yielded when the
@@ -1785,6 +1789,22 @@ struct CombineNestedIfs : public OpRewritePattern<IfOp> {
         // If the correctness test passes, we will yield
         // corresponding value from the inner scf.if
         thenYield[tup.index()] = nestedIf.thenYield().getOperand(nestedIdx);
+        continue;
+      }
+
+      // Otherwise, we need to ensure the else block of the combined
+      // condition still returns the same value when the outer condition is
+      // true and the inner condition is false. This can be accomplished if
+      // the then value is defined outside the outer scf.if and we replace the
+      // value with a select that considers just the outer condition. Since
+      // the else region contains just the yield, its yielded value is
+      // defined outside the scf.if, by definition.
+
+      // If the then value is defined within the scf.if, bail.
+      if (tup.value().getParentRegion() == &op.getThenRegion()) {
+        return failure();
+      } else {
+        elseYieldsToUpgradeToSelect.push_back(tup.index());
       }
     }
 
@@ -1792,6 +1812,15 @@ struct CombineNestedIfs : public OpRewritePattern<IfOp> {
     Value newCondition = rewriter.create<arith::AndIOp>(
         loc, op.getCondition(), nestedIf.getCondition());
     auto newIf = rewriter.create<IfOp>(loc, op.getResultTypes(), newCondition);
+
+    SmallVector<Value> results;
+    llvm::append_range(results, newIf.getResults());
+    rewriter.setInsertionPoint(newIf);
+
+    for (auto idx : elseYieldsToUpgradeToSelect)
+      results[idx] = rewriter.create<arith::SelectOp>(
+          op.getLoc(), op.getCondition(), thenYield[idx], elseYield[idx]);
+
     Block *newIfBlock = newIf.thenBlock();
     if (newIfBlock)
       rewriter.eraseOp(newIfBlock->getTerminator());
@@ -1805,7 +1834,7 @@ struct CombineNestedIfs : public OpRewritePattern<IfOp> {
       rewriter.setInsertionPointToEnd(newIf.elseBlock());
       rewriter.create<YieldOp>(loc, elseYield);
     }
-    rewriter.replaceOp(op, newIf.getResults());
+    rewriter.replaceOp(op, results);
     return success();
   }
 };

diff  --git a/mlir/test/Dialect/SCF/canonicalize.mlir b/mlir/test/Dialect/SCF/canonicalize.mlir
index c9c2b08c211ee..de176fb588d87 100644
--- a/mlir/test/Dialect/SCF/canonicalize.mlir
+++ b/mlir/test/Dialect/SCF/canonicalize.mlir
@@ -556,7 +556,7 @@ func @merge_yielding_nested_if_nv2(%arg0: i1, %arg1: i1) -> i32 {
 // CHECK: %[[PRE0:.*]] = "test.op"() : () -> i32
 // CHECK: %[[PRE1:.*]] = "test.op1"() : () -> i32
 // CHECK: %[[COND:.*]] = arith.andi %[[ARG0]], %[[ARG1]]
-// CHECK: %[[RES:.*]] = arith.select %[[COND]], %[[PRE0]], %[[PRE1]]
+// CHECK: %[[RES:.*]] = arith.select %[[ARG0]], %[[PRE0]], %[[PRE1]]
 // CHECK: scf.if %[[COND]] 
 // CHECK:   "test.run"() : () -> ()
 // CHECK: }
@@ -596,6 +596,7 @@ func @merge_fail_yielding_nested_if(%arg0: i1, %arg1: i1) -> (i32, f32, i32, i8)
   }
   return %r#0, %r#1, %r#2, %r#3 : i32, f32, i32, i8
 }
+
 // -----
 
 // CHECK-LABEL:   func @if_condition_swap


        


More information about the Mlir-commits mailing list