[Mlir-commits] [mlir] [MLIR][Vector] Improve warp distribution robustness (PR #161647)

Artem Kroviakov llvmlistbot at llvm.org
Thu Oct 2 03:06:31 PDT 2025


https://github.com/akroviakov created https://github.com/llvm/llvm-project/pull/161647

This PR improves the warp distribution robustness by:
1. Ensuring that, during warp result deduplication, results with no uses are not mapped to a non-existent index. Currently we map them to `newResultTypes.size()`, but may then skip appending the corresponding entry, leading to a later out-of-bounds (OOB) error.
2. Simplifying the `scf.if` and `scf.for` handling by using `moveRegionToNewWarpOpAndAppendReturns`, which also performs warp result deduplication in place. This avoids cases where, for example, after sinking two `scf.if` ops that need the same escaping value, a _higher-ranked_ sink pattern tries to lower the producer of the escaping value (which is yielded twice at this point) before `WarpOpDeadResult` actually deduplicates the result, leading to the same op being sunk twice (once per yield operand).

>From 04ee917689808ed26f5119ab77fab6b63dbcf94b Mon Sep 17 00:00:00 2001
From: Artem Kroviakov <artem.kroviakov at intel.com>
Date: Thu, 2 Oct 2025 09:45:32 +0000
Subject: [PATCH] [MLIR][Vector] Improve warp distribution robustness

---
 .../Vector/Transforms/VectorDistribute.cpp    | 59 +++++++------------
 .../Vector/vector-warp-distribute.mlir        | 19 ++++++
 2 files changed, 39 insertions(+), 39 deletions(-)

diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorDistribute.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorDistribute.cpp
index e95338f7d18be..47aa1ca40fb03 100644
--- a/mlir/lib/Dialect/Vector/Transforms/VectorDistribute.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/VectorDistribute.cpp
@@ -934,11 +934,13 @@ struct WarpOpDeadResult : public WarpDistributionPattern {
     //   3. skipping from the new result types / new yielded values any result
     //      that has no use or whose yielded value has already been seen.
     for (OpResult result : warpOp.getResults()) {
+      if (result.use_empty())
+        continue;
       Value yieldOperand = yield.getOperand(result.getResultNumber());
       auto it = dedupYieldOperandPositionMap.insert(
           std::make_pair(yieldOperand, newResultTypes.size()));
       dedupResultPositionMap.insert(std::make_pair(result, it.first->second));
-      if (result.use_empty() || !it.second)
+      if (!it.second)
         continue;
       newResultTypes.push_back(result.getType());
       newYieldValues.push_back(yieldOperand);
@@ -1843,16 +1845,16 @@ struct WarpOpScfIfOp : public WarpDistributionPattern {
     newWarpOpDistTypes.append(escapingValueDistTypesElse.begin(),
                               escapingValueDistTypesElse.end());
 
-    llvm::SmallDenseMap<unsigned, unsigned> origToNewYieldIdx;
     for (auto [idx, val] :
          llvm::zip_equal(nonIfYieldIndices, nonIfYieldValues)) {
-      origToNewYieldIdx[idx] = newWarpOpYieldValues.size();
       newWarpOpYieldValues.push_back(val);
       newWarpOpDistTypes.push_back(warpOp.getResult(idx).getType());
     }
-    // Create the new `WarpOp` with the updated yield values and types.
-    WarpExecuteOnLane0Op newWarpOp = moveRegionToNewWarpOpAndReplaceReturns(
-        rewriter, warpOp, newWarpOpYieldValues, newWarpOpDistTypes);
+    // Replace the old `WarpOp` with the new one that has additional yield
+    // values and types.
+    SmallVector<size_t> newIndices;
+    WarpExecuteOnLane0Op newWarpOp = moveRegionToNewWarpOpAndAppendReturns(
+        rewriter, warpOp, newWarpOpYieldValues, newWarpOpDistTypes, newIndices);
     // `ifOp` returns the result of the inner warp op.
     SmallVector<Type> newIfOpDistResTypes;
     for (auto [i, res] : llvm::enumerate(ifOp.getResults())) {
@@ -1870,8 +1872,8 @@ struct WarpOpScfIfOp : public WarpDistributionPattern {
     OpBuilder::InsertionGuard g(rewriter);
     rewriter.setInsertionPointAfter(newWarpOp);
     auto newIfOp = scf::IfOp::create(
-        rewriter, ifOp.getLoc(), newIfOpDistResTypes, newWarpOp.getResult(0),
-        static_cast<bool>(ifOp.thenBlock()),
+        rewriter, ifOp.getLoc(), newIfOpDistResTypes,
+        newWarpOp.getResult(newIndices[0]), static_cast<bool>(ifOp.thenBlock()),
         static_cast<bool>(ifOp.elseBlock()));
     auto encloseRegionInWarpOp =
         [&](Block *oldIfBranch, Block *newIfBranch,
@@ -1888,7 +1890,7 @@ struct WarpOpScfIfOp : public WarpDistributionPattern {
           for (size_t i = 0; i < escapingValues.size();
                ++i, ++warpResRangeStart) {
             innerWarpInputVals.push_back(
-                newWarpOp.getResult(warpResRangeStart));
+                newWarpOp.getResult(newIndices[warpResRangeStart]));
             escapeValToBlockArgIndex[escapingValues[i]] =
                 innerWarpInputTypes.size();
             innerWarpInputTypes.push_back(escapingValueInputTypes[i]);
@@ -1936,17 +1938,8 @@ struct WarpOpScfIfOp : public WarpDistributionPattern {
     // Update the users of `<- WarpOp.yield <- IfOp.yield` to use the new `IfOp`
     // result.
     for (auto [origIdx, newIdx] : ifResultMapping)
-      rewriter.replaceAllUsesExcept(warpOp.getResult(origIdx),
+      rewriter.replaceAllUsesExcept(newWarpOp.getResult(origIdx),
                                     newIfOp.getResult(newIdx), newIfOp);
-    // Similarly, update any users of the `WarpOp` results that were not
-    // results of the `IfOp`.
-    for (auto [origIdx, newIdx] : origToNewYieldIdx)
-      rewriter.replaceAllUsesWith(warpOp.getResult(origIdx),
-                                  newWarpOp.getResult(newIdx));
-    // Remove the original `WarpOp` and `IfOp`, they should not have any uses
-    // at this point.
-    rewriter.eraseOp(ifOp);
-    rewriter.eraseOp(warpOp);
     return success();
   }
 
@@ -2065,19 +2058,16 @@ struct WarpOpScfForOp : public WarpDistributionPattern {
                               escapingValueDistTypes.begin(),
                               escapingValueDistTypes.end());
     // Next, we insert all non-`ForOp` yielded values and their distributed
-    // types. We also create a mapping between the non-`ForOp` yielded value
-    // index and the corresponding new `WarpOp` yield value index (needed to
-    // update users later).
-    llvm::SmallDenseMap<unsigned, unsigned> nonForResultMapping;
+    // types.
     for (auto [i, v] :
          llvm::zip_equal(nonForResultIndices, nonForYieldedValues)) {
-      nonForResultMapping[i] = newWarpOpYieldValues.size();
       newWarpOpYieldValues.push_back(v);
       newWarpOpDistTypes.push_back(warpOp.getResult(i).getType());
     }
     // Create the new `WarpOp` with the updated yield values and types.
-    WarpExecuteOnLane0Op newWarpOp = moveRegionToNewWarpOpAndReplaceReturns(
-        rewriter, warpOp, newWarpOpYieldValues, newWarpOpDistTypes);
+    SmallVector<size_t> newIndices;
+    WarpExecuteOnLane0Op newWarpOp = moveRegionToNewWarpOpAndAppendReturns(
+        rewriter, warpOp, newWarpOpYieldValues, newWarpOpDistTypes, newIndices);
 
     // Next, we create a new `ForOp` with the init args yielded by the new
     // `WarpOp`.
@@ -2086,7 +2076,7 @@ struct WarpOpScfForOp : public WarpDistributionPattern {
                                     // escaping values in the new `WarpOp`.
     SmallVector<Value> newForOpOperands;
     for (size_t i = 0; i < escapingValuesStartIdx; ++i)
-      newForOpOperands.push_back(newWarpOp.getResult(i));
+      newForOpOperands.push_back(newWarpOp.getResult(newIndices[i]));
 
     // Create a new `ForOp` outside the new `WarpOp` region.
     OpBuilder::InsertionGuard g(rewriter);
@@ -2110,7 +2100,7 @@ struct WarpOpScfForOp : public WarpDistributionPattern {
     llvm::SmallDenseMap<Value, int64_t> argIndexMapping;
     for (size_t i = escapingValuesStartIdx;
          i < escapingValuesStartIdx + escapingValues.size(); ++i) {
-      innerWarpInput.push_back(newWarpOp.getResult(i));
+      innerWarpInput.push_back(newWarpOp.getResult(newIndices[i]));
       argIndexMapping[escapingValues[i - escapingValuesStartIdx]] =
           innerWarpInputType.size();
       innerWarpInputType.push_back(
@@ -2146,20 +2136,11 @@ struct WarpOpScfForOp : public WarpDistributionPattern {
     if (!innerWarp.getResults().empty())
       scf::YieldOp::create(rewriter, forOp.getLoc(), innerWarp.getResults());
 
-    // Update the users of original `WarpOp` results that were coming from the
+    // Update the users of the new `WarpOp` results that were coming from the
     // original `ForOp` to the corresponding new `ForOp` result.
     for (auto [origIdx, newIdx] : forResultMapping)
-      rewriter.replaceAllUsesExcept(warpOp.getResult(origIdx),
+      rewriter.replaceAllUsesExcept(newWarpOp.getResult(origIdx),
                                     newForOp.getResult(newIdx), newForOp);
-    // Similarly, update any users of the `WarpOp` results that were not
-    // results of the `ForOp`.
-    for (auto [origIdx, newIdx] : nonForResultMapping)
-      rewriter.replaceAllUsesWith(warpOp.getResult(origIdx),
-                                  newWarpOp.getResult(newIdx));
-    // Remove the original `WarpOp` and `ForOp`, they should not have any uses
-    // at this point.
-    rewriter.eraseOp(forOp);
-    rewriter.eraseOp(warpOp);
     // Update any users of escaping values that were forwarded to the
     // inner `WarpOp`. These values are now arguments of the inner `WarpOp`.
     newForOp.walk([&](Operation *op) {
diff --git a/mlir/test/Dialect/Vector/vector-warp-distribute.mlir b/mlir/test/Dialect/Vector/vector-warp-distribute.mlir
index bb7639204022f..401cdd29b281c 100644
--- a/mlir/test/Dialect/Vector/vector-warp-distribute.mlir
+++ b/mlir/test/Dialect/Vector/vector-warp-distribute.mlir
@@ -1925,3 +1925,22 @@ func.func @warp_scf_if_distribute(%pred : i1)  {
 //       CHECK-PROP:    "some_use"(%[[IF_YIELD_DIST]]) : (vector<1xf32>) -> ()
 //       CHECK-PROP:    return
 //       CHECK-PROP:  }
+
+// -----
+func.func @dedup_unused_result(%laneid : index) -> (vector<1xf32>) {
+  %r:3 = gpu.warp_execute_on_lane_0(%laneid)[32] ->
+    (vector<1xf32>, vector<2xf32>, vector<1xf32>) {
+    %2 = "some_def"() : () -> (vector<32xf32>)
+    %3 = "some_def"() : () -> (vector<64xf32>)
+    gpu.yield %2, %3, %2 : vector<32xf32>, vector<64xf32>, vector<32xf32>
+  }
+  %r0 = "some_use"(%r#2, %r#2) : (vector<1xf32>, vector<1xf32>) -> (vector<1xf32>)
+  return %r0 : vector<1xf32>
+}
+
+// CHECK-PROP: func @dedup_unused_result
+// CHECK-PROP: %[[R:.*]] = gpu.warp_execute_on_lane_0(%arg0)[32] -> (vector<1xf32>)
+// CHECK-PROP:   %[[Y0:.*]] = "some_def"() : () -> vector<32xf32>
+// CHECK-PROP:   %[[Y1:.*]] = "some_def"() : () -> vector<64xf32>
+// CHECK-PROP:   gpu.yield %[[Y0]] : vector<32xf32>
+// CHECK-PROP: "some_use"(%[[R]], %[[R]]) : (vector<1xf32>, vector<1xf32>) -> vector<1xf32>



More information about the Mlir-commits mailing list