[Mlir-commits] [mlir] [MLIR][Vector] Fix WarpOpScfForOp and WarpOpScfIfOp leaving invalid ops after region moves (PR #188951)

Mehdi Amini llvmlistbot at llvm.org
Mon Apr 13 10:55:20 PDT 2026


https://github.com/joker-eph updated https://github.com/llvm/llvm-project/pull/188951

>From d6f25e17937c51191e72b135b9fcc80dcbe42f89 Mon Sep 17 00:00:00 2001
From: Mehdi Amini <joker.eph at gmail.com>
Date: Thu, 26 Mar 2026 15:57:49 -0700
Subject: [PATCH] [MLIR][Vector] Fix WarpOpScfForOp and WarpOpScfIfOp leaving
 invalid ops after region moves

WarpOpScfForOp::matchAndRewrite called mergeBlocks() to move forOp's body
block into the inner WarpOp. mergeBlocks() erases the source block, leaving
forOp with an empty body region (0 blocks). Since scf.for requires exactly
1 body block, IR verification fails with "region with 1 blocks" after the
pattern succeeds. Additionally, when forOp had no init args, the pattern was
missing the scf.yield terminator in the new ForOp.

WarpOpScfIfOp::matchAndRewrite had the same issue: takeBody() emptied the
ifOp's then/else regions, leaving scf.if with 0 blocks.

Fix:
- Restore the conditional scf.yield creation (only when newForOp has results).
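- After merging/taking the regions, clear the remaining uses of the original
  op's results (yielding ub.poison in place of the scf.if results, and
  forwarding the scf.for results to the corresponding init args), then erase
  the now-invalid op from the new WarpOp's body.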

Assisted-by: Claude Code
Fixes a failure that occurs with MLIR_ENABLE_EXPENSIVE_PATTERN_API_CHECKS=ON.
---
 .../Vector/Transforms/VectorDistribute.cpp    | 35 +++++++++++++++++++
 1 file changed, 35 insertions(+)

diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorDistribute.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorDistribute.cpp
index 2e0e650f2bb9c..5d1aa1113db8d 100644
--- a/mlir/lib/Dialect/Vector/Transforms/VectorDistribute.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/VectorDistribute.cpp
@@ -12,6 +12,7 @@
 #include "mlir/Dialect/GPU/Utils/DistributionUtils.h"
 #include "mlir/Dialect/MemRef/IR/MemRef.h"
 #include "mlir/Dialect/SCF/IR/SCF.h"
+#include "mlir/Dialect/UB/IR/UBOps.h"
 #include "mlir/Dialect/Vector/IR/VectorOps.h"
 #include "mlir/Dialect/Vector/Transforms/VectorDistribution.h"
 #include "mlir/IR/AffineExpr.h"
@@ -20,6 +21,7 @@
 #include "mlir/Interfaces/SideEffectInterfaces.h"
 #include "mlir/Transforms/RegionUtils.h"
 #include "llvm/ADT/SetVector.h"
+#include "llvm/ADT/SmallBitVector.h"
 #include "llvm/ADT/SmallVectorExtras.h"
 #include "llvm/Support/FormatVariadic.h"
 #include <utility>
@@ -2009,6 +2011,26 @@ struct WarpOpScfIfOp : public WarpDistributionPattern {
     for (auto [origIdx, newIdx] : ifResultMapping)
       rewriter.replaceAllUsesExcept(newWarpOp.getResult(origIdx),
                                     newIfOp.getResult(newIdx), newIfOp);
+
+    // The original `ifOp` was left inside `newWarpOp` with empty then/else
+    // regions (their blocks were moved into the inner WarpOps by takeBody).
+    // Clear remaining uses and erase it to restore IR validity. Directly
+    // update newWarpOp's yield operands instead of using replaceAllUsesWith,
+    // to avoid triggering notifyOperandReplaced on the now-invalid ifOp.
+    {
+      OpBuilder::InsertionGuard guard(rewriter);
+      rewriter.setInsertionPoint(ifOp);
+      Operation *yield = newWarpOp.getTerminator();
+      rewriter.modifyOpInPlace(yield, [&]() {
+        for (auto [origIdx, ifResultIdx] : ifResultMapping) {
+          Value poison = ub::PoisonOp::create(
+              rewriter, ifOp.getLoc(), ifOp.getResult(ifResultIdx).getType());
+          yield->setOperand(origIdx, poison);
+        }
+      });
+      rewriter.eraseOp(ifOp);
+    }
+
     return success();
   }
 
@@ -2080,6 +2102,7 @@ struct WarpOpScfForOp : public WarpDistributionPattern {
     SmallVector<unsigned> nonForResultIndices;
     llvm::SmallDenseMap<unsigned, unsigned> forResultMapping;
     llvm::SmallDenseMap<unsigned, VectorType> forResultDistTypes;
+    llvm::SmallBitVector forResultsMapped(forOp.getNumResults());
     for (OpOperand &yieldOperand : warpOpYield->getOpOperands()) {
       // Yielded value is not a result of the forOp.
       if (yieldOperand.get().getDefiningOp() != forOp.getOperation()) {
@@ -2090,6 +2113,7 @@ struct WarpOpScfForOp : public WarpDistributionPattern {
       OpResult forResult = cast<OpResult>(yieldOperand.get());
       unsigned int forResultNumber = forResult.getResultNumber();
       forResultMapping[yieldOperand.getOperandNumber()] = forResultNumber;
+      forResultsMapped.set(forResultNumber);
       // If this `ForOp` result is vector type and it is yielded by the
       // `WarpOp`, we keep track the distributed type for this result.
       if (!isa<VectorType>(forResult.getType()))
@@ -2224,6 +2248,17 @@ struct WarpOpScfForOp : public WarpDistributionPattern {
     for (auto [origIdx, newIdx] : forResultMapping)
       rewriter.replaceAllUsesExcept(newWarpOp.getResult(origIdx),
                                     newForOp.getResult(newIdx), newForOp);
+
+    // The original `ForOp` was left inside `newWarpOp` with an empty body
+    // region (its body block was moved into `innerWarp` by `mergeBlocks`).
+    // Forward remaining uses to init args, then erase it so the IR is valid.
+    for (OpResult result : forOp.getResults()) {
+      if (forResultsMapped.test(result.getResultNumber()))
+        rewriter.replaceAllUsesWith(
+            result, forOp.getInitArgs()[result.getResultNumber()]);
+    }
+    rewriter.eraseOp(forOp);
+
     // Update any users of escaping values that were forwarded to the
     // inner `WarpOp`. These values are now arguments of the inner `WarpOp`.
     newForOp.walk([&](Operation *op) {



More information about the Mlir-commits mailing list