[Mlir-commits] [mlir] [MLIR][Vector] Fix WarpOpScfForOp and WarpOpScfIfOp leaving invalid ops after region moves (PR #188951)

Mehdi Amini llvmlistbot at llvm.org
Wed Apr 15 09:13:36 PDT 2026


================
@@ -2009,6 +2011,26 @@ struct WarpOpScfIfOp : public WarpDistributionPattern {
     for (auto [origIdx, newIdx] : ifResultMapping)
       rewriter.replaceAllUsesExcept(newWarpOp.getResult(origIdx),
                                     newIfOp.getResult(newIdx), newIfOp);
+
+    // The original `ifOp` was left inside `newWarpOp` with empty then/else
+    // regions (their blocks were moved into the inner WarpOps by takeBody).
+    // Clear remaining uses and erase it to restore IR validity. Directly
+    // update newWarpOp's yield operands instead of using replaceAllUsesWith,
+    // to avoid triggering notifyOperandReplaced on the now-invalid ifOp.
+    {
+      OpBuilder::InsertionGuard guard(rewriter);
+      rewriter.setInsertionPoint(ifOp);
+      Operation *yield = newWarpOp.getTerminator();
+      rewriter.modifyOpInPlace(yield, [&]() {
+        for (auto [origIdx, ifResultIdx] : ifResultMapping) {
+          Value poison = ub::PoisonOp::create(
+              rewriter, ifOp.getLoc(), ifOp.getResult(ifResultIdx).getType());
+          yield->setOperand(origIdx, poison);
+        }
+      });
+      rewriter.eraseOp(ifOp);
+    }
----------------
joker-eph wrote:

Here is the state of the IR at this point. We go from:

```
%1 = gpu.warp_execute_on_lane_0(%0)[32] -> (vector<1xf32>) {
  %2 = vector.step : vector<32xindex>
  %3 = scf.if %arg0 -> (vector<32xf32>) {
    %4 = "some_op"(%2) : (vector<32xindex>) -> vector<32xf32>
    scf.yield %4 : vector<32xf32>
  } else {
    %4 = "other_op"(%cst) : (vector<32xindex>) -> vector<32xf32>
    scf.yield %4 : vector<32xf32>
  }
  gpu.yield %3 : vector<32xf32>
}
```

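To the following (printed in generic form), where the original `scf.if` is left behind with two empty regions: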

```
%3:3 = "gpu.warp_execute_on_lane_0"(%2) <{warp_size = 32 : i64}> ({
  %9 = "vector.step"() : () -> vector<32xindex>
  %10 = "scf.if"(%arg0) ({
  }, {
  }) : (i1) -> vector<32xf32>
  "gpu.yield"(%10, %arg0, %9) : (vector<32xf32>, i1, vector<32xindex>) -> ()
}) : (index) -> (vector<1xf32>, i1, vector<1xindex>)
```
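The comment in the diff describes a two-step repair: first point every `gpu.yield` operand that still refers to the hollowed-out `scf.if` at a fresh `ub.poison` value, then erase the `scf.if`, which by then has no users. The sketch below models only that ordering. It assumes a hypothetical toy IR: `Op`, `Value`, `Block`, and `replaceWithPoisonAndErase` are illustrative names, not MLIR's C++ API. The toy `erase` asserts that no uses remain, mirroring the requirement that an op be use-free when it is erased.

```cpp
#include <algorithm>
#include <cassert>
#include <cstddef>
#include <memory>
#include <string>
#include <utility>
#include <vector>

// Hypothetical toy IR, NOT the MLIR API: an Op owns its result Values and
// refers to its operands by raw pointer, so erasing an Op that still has
// users would leave those operands dangling.
struct Op;

struct Value {
  Op *definingOp;
  std::string type;
};

struct Op {
  std::string name;
  std::vector<Value *> operands;
  std::vector<std::unique_ptr<Value>> results;
};

struct Block {
  std::vector<std::unique_ptr<Op>> ops;

  // Creates an op immediately before `anchor` (or at the end if null).
  Op *create(std::string name, std::vector<Value *> operands,
             const std::vector<std::string> &resultTypes,
             Op *anchor = nullptr) {
    auto op = std::make_unique<Op>();
    op->name = std::move(name);
    op->operands = std::move(operands);
    for (const std::string &type : resultTypes)
      op->results.push_back(std::make_unique<Value>(Value{op.get(), type}));
    Op *raw = op.get();
    ops.insert(position(anchor), std::move(op));
    return raw;
  }

  // Number of operands, across all ops in the block, that use a result of `op`.
  std::size_t countUses(const Op *op) const {
    std::size_t uses = 0;
    for (const auto &user : ops)
      for (const Value *operand : user->operands)
        if (operand->definingOp == op)
          ++uses;
    return uses;
  }

  // Erases `op`; callers must first remove every use of its results.
  void erase(Op *op) {
    assert(countUses(op) == 0 && "erasing an op whose results are still used");
    auto it = position(op);
    assert(it != ops.end() && "op is not in this block");
    ops.erase(it);
  }

private:
  std::vector<std::unique_ptr<Op>>::iterator position(const Op *anchor) {
    return std::find_if(ops.begin(), ops.end(),
                        [&](const auto &op) { return op.get() == anchor; });
  }
};

// Same ordering as the diff: for each (terminator operand index, ifOp result
// index) pair, create a poison placeholder just before `ifOp` and rewire the
// terminator operand to it; only then erase `ifOp`.
void replaceWithPoisonAndErase(
    Block &block, Op *ifOp, Op *terminator,
    const std::vector<std::pair<std::size_t, std::size_t>> &mapping) {
  for (auto [operandIdx, resultIdx] : mapping) {
    Op *poison = block.create("ub.poison", {},
                              {ifOp->results[resultIdx]->type}, ifOp);
    terminator->operands[operandIdx] = poison->results[0].get();
  }
  block.erase(ifOp);
}

// Usage: model the warp body from the IR above and apply the repair.
std::vector<std::string> demo() {
  Block body;
  Op *step = body.create("vector.step", {}, {"vector<32xindex>"});
  Op *ifOp = body.create("scf.if", {}, {"vector<32xf32>"});
  Op *yield = body.create(
      "gpu.yield", {ifOp->results[0].get(), step->results[0].get()}, {});
  replaceWithPoisonAndErase(body, ifOp, yield, {{0, 0}});
  std::vector<std::string> names;
  for (const auto &op : body.ops)
    names.push_back(op->name);
  return names;
}
```

In this toy model the terminator is the only remaining user of the `scf.if` result, so rewriting its operands in place is enough for `erase` to succeed; the diff comment likewise treats `newWarpOp`'s yield as the remaining user to clear before `eraseOp`.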

https://github.com/llvm/llvm-project/pull/188951