[Mlir-commits] [mlir] [mlir][vector] Refactor WarpOpScfForOp to support unused or swapped forOp results. (PR #147620)

Charitha Saumya llvmlistbot at llvm.org
Fri Jul 11 12:47:52 PDT 2025


https://github.com/charithaintc updated https://github.com/llvm/llvm-project/pull/147620

>From 586839035dfaaf45d66ad6b1184f94465c10906f Mon Sep 17 00:00:00 2001
From: Charitha Saumya <charitha.saumya.gusthinna.waduge at intel.com>
Date: Wed, 2 Jul 2025 17:29:08 +0000
Subject: [PATCH 01/10] working but bug in dead result

---
 .../Vector/Transforms/VectorDistribute.cpp    | 66 ++++++++++++++-----
 1 file changed, 50 insertions(+), 16 deletions(-)

diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorDistribute.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorDistribute.cpp
index af90ed8f5deaf..28c957bf61921 100644
--- a/mlir/lib/Dialect/Vector/Transforms/VectorDistribute.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/VectorDistribute.cpp
@@ -17,6 +17,7 @@
 #include "mlir/IR/AffineExpr.h"
 #include "mlir/IR/Attributes.h"
 #include "mlir/IR/BuiltinTypes.h"
+#include "mlir/IR/Value.h"
 #include "mlir/Interfaces/SideEffectInterfaces.h"
 #include "mlir/Transforms/RegionUtils.h"
 #include "llvm/ADT/SetVector.h"
@@ -1777,24 +1778,42 @@ struct WarpOpScfForOp : public WarpDistributionPattern {
     if (llvm::is_contained(distTypes, Type{}))
       return failure();
 
+    llvm::errs() << "escpaing values size: " << escapingValues.size() << "\n";
+
+    SmallVector<Value> yieldedValuesFromWarpOp;
+    // All init args of the forOp are yielded from the original warp op.
+    for (Value initArg : forOp.getInitArgs()) {
+      yieldedValuesFromWarpOp.push_back(initArg);
+      // find distributed type for the init arg.
+      Type distType = initArg.getType();
+      if (auto vecType = dyn_cast<VectorType>(distType)) {
+        AffineMap map = distributionMapFn(initArg);
+        distType = getDistributedType(vecType, map, warpOp.getWarpSize());
+      }
+      distTypes.push_back(distType);
+    }
+    // All escaping values are yielded from the original warp op.
+    yieldedValuesFromWarpOp.insert(yieldedValuesFromWarpOp.end(),
+                                   escapingValues.begin(),
+                                   escapingValues.end());
+
     SmallVector<size_t> newRetIndices;
     WarpExecuteOnLane0Op newWarpOp = moveRegionToNewWarpOpAndAppendReturns(
-        rewriter, warpOp, escapingValues.getArrayRef(), distTypes,
-        newRetIndices);
+        rewriter, warpOp, yieldedValuesFromWarpOp, distTypes, newRetIndices);
     yield = cast<gpu::YieldOp>(
         newWarpOp.getBodyRegion().getBlocks().begin()->getTerminator());
 
     SmallVector<Value> newOperands;
     SmallVector<unsigned> resultIdx;
-    // Collect all the outputs coming from the forOp.
+    // Collect the new init args coming from the new warp op.
+    for (size_t i = 0; i < forOp.getInitArgs().size(); ++i)
+      newOperands.push_back(newWarpOp.getResult(newRetIndices[i]));
     for (OpOperand &yieldOperand : yield->getOpOperands()) {
       if (yieldOperand.get().getDefiningOp() != forOp.getOperation())
         continue;
-      auto forResult = cast<OpResult>(yieldOperand.get());
-      newOperands.push_back(
-          newWarpOp.getResult(yieldOperand.getOperandNumber()));
+      OpResult forResult = cast<OpResult>(yieldOperand.get());
+      resultIdx.push_back(forResult.getResultNumber());
       yieldOperand.set(forOp.getInitArgs()[forResult.getResultNumber()]);
-      resultIdx.push_back(yieldOperand.getOperandNumber());
     }
 
     OpBuilder::InsertionGuard g(rewriter);
@@ -1812,8 +1831,8 @@ struct WarpOpScfForOp : public WarpDistributionPattern {
     SmallVector<Type> warpInputType(forOp.getResultTypes().begin(),
                                     forOp.getResultTypes().end());
     llvm::SmallDenseMap<Value, int64_t> argIndexMapping;
-    for (auto [i, retIdx] : llvm::enumerate(newRetIndices)) {
-      warpInput.push_back(newWarpOp.getResult(retIdx));
+    for (size_t i = forOp.getInitArgs().size(); i < newRetIndices.size(); ++i) {
+      warpInput.push_back(newWarpOp.getResult(i));
       argIndexMapping[escapingValues[i]] = warpInputType.size();
       warpInputType.push_back(inputTypes[i]);
     }
@@ -1826,24 +1845,37 @@ struct WarpOpScfForOp : public WarpDistributionPattern {
     for (Value args : innerWarp.getBody()->getArguments()) {
       argMapping.push_back(args);
     }
-    argMapping.resize(forOp.getBody()->getNumArguments());
+    auto forOpCopy = cast<scf::ForOp>(rewriter.clone(*forOp.getOperation()));
+    argMapping.resize(forOpCopy.getBody()->getNumArguments());
     SmallVector<Value> yieldOperands;
-    for (Value operand : forOp.getBody()->getTerminator()->getOperands())
+    for (Value operand : forOpCopy.getBody()->getTerminator()->getOperands())
       yieldOperands.push_back(operand);
-    rewriter.eraseOp(forOp.getBody()->getTerminator());
-    rewriter.mergeBlocks(forOp.getBody(), innerWarp.getBody(), argMapping);
+
+    rewriter.eraseOp(forOpCopy.getBody()->getTerminator());
+    rewriter.mergeBlocks(forOpCopy.getBody(), innerWarp.getBody(), argMapping);
     rewriter.setInsertionPointToEnd(innerWarp.getBody());
     rewriter.create<gpu::YieldOp>(innerWarp.getLoc(), yieldOperands);
     rewriter.setInsertionPointAfter(innerWarp);
     if (!innerWarp.getResults().empty())
-      rewriter.create<scf::YieldOp>(forOp.getLoc(), innerWarp.getResults());
-    rewriter.eraseOp(forOp);
+      rewriter.create<scf::YieldOp>(forOpCopy.getLoc(), innerWarp.getResults());
+    // forOpCopy->getParentOp()->getParentOp()->print(llvm::outs());
+    // llvm::outs() << "\n";
+    // llvm::errs() << "erasing for op\n";
+
+    rewriter.eraseOp(forOpCopy);
     // Replace the warpOp result coming from the original ForOp.
+    // print resultIdx for debugging.
+    llvm::errs() << "resultIdx: ";
+    for (auto idx : resultIdx)
+      llvm::errs() << idx << " ";
+    llvm::errs() << "\n";
     for (const auto &res : llvm::enumerate(resultIdx)) {
       rewriter.replaceAllUsesWith(newWarpOp.getResult(res.value()),
                                   newForOp.getResult(res.index()));
-      newForOp->setOperand(res.index() + 3, newWarpOp.getResult(res.value()));
+      // newForOp->setOperand(res.index() + 3,
+      // newWarpOp.getResult(res.value()));
     }
+    rewriter.eraseOp(forOp);
     newForOp.walk([&](Operation *op) {
       for (OpOperand &operand : op->getOpOperands()) {
         auto it = argIndexMapping.find(operand.get());
@@ -1852,6 +1884,8 @@ struct WarpOpScfForOp : public WarpDistributionPattern {
         operand.set(innerWarp.getBodyRegion().getArgument(it->second));
       }
     });
+    newForOp->getParentOp()->print(llvm::outs());
+    llvm::outs() << "\n";
 
     // Finally, hoist out any now uniform code from the inner warp op.
     mlir::vector::moveScalarUniformCode(innerWarp);

>From 4c363175e0c5a0d6cddfe7ac3532051f76d88039 Mon Sep 17 00:00:00 2001
From: Charitha Saumya <charitha.saumya.gusthinna.waduge at intel.com>
Date: Wed, 2 Jul 2025 17:59:39 +0000
Subject: [PATCH 02/10] working version

---
 .../Vector/Transforms/VectorDistribute.cpp    | 40 ++++++++++++-------
 1 file changed, 25 insertions(+), 15 deletions(-)

diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorDistribute.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorDistribute.cpp
index 28c957bf61921..dae62d2cecc04 100644
--- a/mlir/lib/Dialect/Vector/Transforms/VectorDistribute.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/VectorDistribute.cpp
@@ -1796,26 +1796,34 @@ struct WarpOpScfForOp : public WarpDistributionPattern {
     yieldedValuesFromWarpOp.insert(yieldedValuesFromWarpOp.end(),
                                    escapingValues.begin(),
                                    escapingValues.end());
-
-    SmallVector<size_t> newRetIndices;
-    WarpExecuteOnLane0Op newWarpOp = moveRegionToNewWarpOpAndAppendReturns(
-        rewriter, warpOp, yieldedValuesFromWarpOp, distTypes, newRetIndices);
-    yield = cast<gpu::YieldOp>(
-        newWarpOp.getBodyRegion().getBlocks().begin()->getTerminator());
-
-    SmallVector<Value> newOperands;
+    // record result mapping.
     SmallVector<unsigned> resultIdx;
-    // Collect the new init args coming from the new warp op.
-    for (size_t i = 0; i < forOp.getInitArgs().size(); ++i)
-      newOperands.push_back(newWarpOp.getResult(newRetIndices[i]));
     for (OpOperand &yieldOperand : yield->getOpOperands()) {
       if (yieldOperand.get().getDefiningOp() != forOp.getOperation())
         continue;
       OpResult forResult = cast<OpResult>(yieldOperand.get());
       resultIdx.push_back(forResult.getResultNumber());
-      yieldOperand.set(forOp.getInitArgs()[forResult.getResultNumber()]);
+      // yieldOperand.set(forOp.getInitArgs()[forResult.getResultNumber()]);
     }
 
+    // SmallVector<size_t> newRetIndices;
+    WarpExecuteOnLane0Op newWarpOp = moveRegionToNewWarpOpAndReplaceReturns(
+        rewriter, warpOp, yieldedValuesFromWarpOp, distTypes);
+    yield = cast<gpu::YieldOp>(
+        newWarpOp.getBodyRegion().getBlocks().begin()->getTerminator());
+
+    SmallVector<Value> newOperands;
+    // Collect the new init args coming from the new warp op.
+    for (size_t i = 0; i < forOp.getInitArgs().size(); ++i)
+      newOperands.push_back(newWarpOp.getResult(i));
+    // for (OpOperand &yieldOperand : yield->getOpOperands()) {
+    //   if (yieldOperand.get().getDefiningOp() != forOp.getOperation())
+    //     continue;
+    //   OpResult forResult = cast<OpResult>(yieldOperand.get());
+    //   resultIdx.push_back(forResult.getResultNumber());
+    //   yieldOperand.set(forOp.getInitArgs()[forResult.getResultNumber()]);
+    // }
+
     OpBuilder::InsertionGuard g(rewriter);
     rewriter.setInsertionPointAfter(newWarpOp);
 
@@ -1831,7 +1839,8 @@ struct WarpOpScfForOp : public WarpDistributionPattern {
     SmallVector<Type> warpInputType(forOp.getResultTypes().begin(),
                                     forOp.getResultTypes().end());
     llvm::SmallDenseMap<Value, int64_t> argIndexMapping;
-    for (size_t i = forOp.getInitArgs().size(); i < newRetIndices.size(); ++i) {
+    for (size_t i = forOp.getInitArgs().size(); i < newWarpOp->getNumResults();
+         ++i) {
       warpInput.push_back(newWarpOp.getResult(i));
       argIndexMapping[escapingValues[i]] = warpInputType.size();
       warpInputType.push_back(inputTypes[i]);
@@ -1870,12 +1879,13 @@ struct WarpOpScfForOp : public WarpDistributionPattern {
       llvm::errs() << idx << " ";
     llvm::errs() << "\n";
     for (const auto &res : llvm::enumerate(resultIdx)) {
-      rewriter.replaceAllUsesWith(newWarpOp.getResult(res.value()),
-                                  newForOp.getResult(res.index()));
+      rewriter.replaceAllUsesExcept(warpOp.getResult(res.value()),
+                                    newForOp.getResult(res.index()), newForOp);
       // newForOp->setOperand(res.index() + 3,
       // newWarpOp.getResult(res.value()));
     }
     rewriter.eraseOp(forOp);
+    rewriter.eraseOp(warpOp);
     newForOp.walk([&](Operation *op) {
       for (OpOperand &operand : op->getOpOperands()) {
         auto it = argIndexMapping.find(operand.get());

>From 3595f1758ad2c71f2f265253cf0a66ec3bdc94d2 Mon Sep 17 00:00:00 2001
From: Charitha Saumya <charitha.saumya.gusthinna.waduge at intel.com>
Date: Wed, 2 Jul 2025 22:10:56 +0000
Subject: [PATCH 03/10] working version refined

---
 .../Vector/Transforms/VectorDistribute.cpp    | 75 +++++++++++++------
 1 file changed, 52 insertions(+), 23 deletions(-)

diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorDistribute.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorDistribute.cpp
index dae62d2cecc04..52a55d104c0bd 100644
--- a/mlir/lib/Dialect/Vector/Transforms/VectorDistribute.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/VectorDistribute.cpp
@@ -20,6 +20,7 @@
 #include "mlir/IR/Value.h"
 #include "mlir/Interfaces/SideEffectInterfaces.h"
 #include "mlir/Transforms/RegionUtils.h"
+#include "llvm/ADT/DenseMap.h"
 #include "llvm/ADT/SetVector.h"
 #include "llvm/ADT/SmallVectorExtras.h"
 #include "llvm/Support/FormatVariadic.h"
@@ -1778,37 +1779,53 @@ struct WarpOpScfForOp : public WarpDistributionPattern {
     if (llvm::is_contained(distTypes, Type{}))
       return failure();
 
-    llvm::errs() << "escpaing values size: " << escapingValues.size() << "\n";
+    // record result mapping.
+    SmallVector<unsigned> resultIdx;
+    llvm::SmallDenseMap<unsigned, unsigned> forResultToWarpResultMapping;
+    for (OpOperand &yieldOperand : yield->getOpOperands()) {
+      if (yieldOperand.get().getDefiningOp() != forOp.getOperation())
+        continue;
+      OpResult forResult = cast<OpResult>(yieldOperand.get());
+      resultIdx.push_back(forResult.getResultNumber());
+      forResultToWarpResultMapping[forResult.getResultNumber()] =
+          yieldOperand.getOperandNumber();
+      // yieldOperand.set(forOp.getInitArgs()[forResult.getResultNumber()]);
+    }
+
+    // llvm::errs() << "escpaing values size: " << escapingValues.size() <<
+    // "\n";
 
     SmallVector<Value> yieldedValuesFromWarpOp;
+    SmallVector<Type> yieldedTypesFromWarpOp;
     // All init args of the forOp are yielded from the original warp op.
-    for (Value initArg : forOp.getInitArgs()) {
+    for (auto [i, initArg] : llvm::enumerate(forOp.getInitArgs())) {
       yieldedValuesFromWarpOp.push_back(initArg);
       // find distributed type for the init arg.
       Type distType = initArg.getType();
       if (auto vecType = dyn_cast<VectorType>(distType)) {
-        AffineMap map = distributionMapFn(initArg);
-        distType = getDistributedType(vecType, map, warpOp.getWarpSize());
+        if (forResultToWarpResultMapping.contains(i)) {
+          // If the init arg is yielded from the warp op, we need to compute the
+          // distributed type.
+          distType =
+              warpOp.getResult(forResultToWarpResultMapping[i]).getType();
+        } else {
+          AffineMap map = distributionMapFn(initArg);
+          distType = getDistributedType(vecType, map, warpOp.getWarpSize());
+        }
       }
-      distTypes.push_back(distType);
+      // llvm::errs() << "distributed type: " << distType << "\n";
+      yieldedTypesFromWarpOp.push_back(distType);
     }
     // All escaping values are yielded from the original warp op.
     yieldedValuesFromWarpOp.insert(yieldedValuesFromWarpOp.end(),
                                    escapingValues.begin(),
                                    escapingValues.end());
-    // record result mapping.
-    SmallVector<unsigned> resultIdx;
-    for (OpOperand &yieldOperand : yield->getOpOperands()) {
-      if (yieldOperand.get().getDefiningOp() != forOp.getOperation())
-        continue;
-      OpResult forResult = cast<OpResult>(yieldOperand.get());
-      resultIdx.push_back(forResult.getResultNumber());
-      // yieldOperand.set(forOp.getInitArgs()[forResult.getResultNumber()]);
-    }
+    yieldedTypesFromWarpOp.insert(yieldedTypesFromWarpOp.end(),
+                                  distTypes.begin(), distTypes.end());
 
     // SmallVector<size_t> newRetIndices;
     WarpExecuteOnLane0Op newWarpOp = moveRegionToNewWarpOpAndReplaceReturns(
-        rewriter, warpOp, yieldedValuesFromWarpOp, distTypes);
+        rewriter, warpOp, yieldedValuesFromWarpOp, yieldedTypesFromWarpOp);
     yield = cast<gpu::YieldOp>(
         newWarpOp.getBodyRegion().getBlocks().begin()->getTerminator());
 
@@ -1839,15 +1856,27 @@ struct WarpOpScfForOp : public WarpDistributionPattern {
     SmallVector<Type> warpInputType(forOp.getResultTypes().begin(),
                                     forOp.getResultTypes().end());
     llvm::SmallDenseMap<Value, int64_t> argIndexMapping;
+    // llvm::errs() << "setting arg index mapping\n";
     for (size_t i = forOp.getInitArgs().size(); i < newWarpOp->getNumResults();
          ++i) {
       warpInput.push_back(newWarpOp.getResult(i));
-      argIndexMapping[escapingValues[i]] = warpInputType.size();
-      warpInputType.push_back(inputTypes[i]);
+      argIndexMapping[escapingValues[i - forOp.getInitArgs().size()]] =
+          warpInputType.size();
+      warpInputType.push_back(inputTypes[i - forOp.getInitArgs().size()]);
     }
+    // for (auto [i, r] : llvm::enumerate(
+    //          newWarpOp.getResults().drop_front(forOp.getInitArgs().size())))
+    //          {
+    //   warpInput.push_back(r);
+    //   argIndexMapping[escapingValues[i]] = warpInputType.size();
+    //   warpInputType.push_back(inputTypes[i]);
+    // }
+    // llvm::errs() << "go here\n";
     auto innerWarp = rewriter.create<WarpExecuteOnLane0Op>(
         newWarpOp.getLoc(), newForOp.getResultTypes(), newWarpOp.getLaneid(),
         newWarpOp.getWarpSize(), warpInput, warpInputType);
+    // newForOp->getParentOp()->print(llvm::outs());
+    // llvm::outs() << "\n";
 
     SmallVector<Value> argMapping;
     argMapping.push_back(newForOp.getInductionVar());
@@ -1874,10 +1903,10 @@ struct WarpOpScfForOp : public WarpDistributionPattern {
     rewriter.eraseOp(forOpCopy);
     // Replace the warpOp result coming from the original ForOp.
     // print resultIdx for debugging.
-    llvm::errs() << "resultIdx: ";
-    for (auto idx : resultIdx)
-      llvm::errs() << idx << " ";
-    llvm::errs() << "\n";
+    // llvm::errs() << "resultIdx: ";
+    // for (auto idx : resultIdx)
+    //   llvm::errs() << idx << " ";
+    // llvm::errs() << "\n";
     for (const auto &res : llvm::enumerate(resultIdx)) {
       rewriter.replaceAllUsesExcept(warpOp.getResult(res.value()),
                                     newForOp.getResult(res.index()), newForOp);
@@ -1894,8 +1923,8 @@ struct WarpOpScfForOp : public WarpDistributionPattern {
         operand.set(innerWarp.getBodyRegion().getArgument(it->second));
       }
     });
-    newForOp->getParentOp()->print(llvm::outs());
-    llvm::outs() << "\n";
+    // newForOp->getParentOp()->print(llvm::outs());
+    // llvm::outs() << "\n";
 
     // Finally, hoist out any now uniform code from the inner warp op.
     mlir::vector::moveScalarUniformCode(innerWarp);

>From ba94ee21098465ea466c79acaaf01953b34cfc70 Mon Sep 17 00:00:00 2001
From: Charitha Saumya <charitha.saumya.gusthinna.waduge at intel.com>
Date: Tue, 8 Jul 2025 00:21:19 +0000
Subject: [PATCH 04/10] working failing case now

---
 .../Vector/Transforms/VectorDistribute.cpp    | 81 +++++++++++++------
 1 file changed, 58 insertions(+), 23 deletions(-)

diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorDistribute.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorDistribute.cpp
index 52a55d104c0bd..adfed18a625b3 100644
--- a/mlir/lib/Dialect/Vector/Transforms/VectorDistribute.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/VectorDistribute.cpp
@@ -19,8 +19,10 @@
 #include "mlir/IR/BuiltinTypes.h"
 #include "mlir/IR/Value.h"
 #include "mlir/Interfaces/SideEffectInterfaces.h"
+#include "mlir/Support/LLVM.h"
 #include "mlir/Transforms/RegionUtils.h"
 #include "llvm/ADT/DenseMap.h"
+#include "llvm/ADT/STLExtras.h"
 #include "llvm/ADT/SetVector.h"
 #include "llvm/ADT/SmallVectorExtras.h"
 #include "llvm/Support/FormatVariadic.h"
@@ -1779,22 +1781,35 @@ struct WarpOpScfForOp : public WarpDistributionPattern {
     if (llvm::is_contained(distTypes, Type{}))
       return failure();
 
+    SmallVector<Value> nonForYieldedValues;
+    // SmallVector<Type> nonForYieldedTypes;
+    SmallVector<unsigned> nonForResultIndices;
+
     // record result mapping.
-    SmallVector<unsigned> resultIdx;
-    llvm::SmallDenseMap<unsigned, unsigned> forResultToWarpResultMapping;
+    DenseMap<unsigned, unsigned> forResultMapping;
+    DenseMap<unsigned, unsigned> warpResultMapping;
+    // llvm::SmallDenseMap<unsigned, unsigned> forResultToWarpResultMapping;
     for (OpOperand &yieldOperand : yield->getOpOperands()) {
-      if (yieldOperand.get().getDefiningOp() != forOp.getOperation())
+      if (yieldOperand.get().getDefiningOp() != forOp.getOperation()) {
+        nonForYieldedValues.push_back(yieldOperand.get());
+        // nonForYieldedTypes.push_back(
+        //     warpOp.getResult(yieldOperand.getOperandNumber()).getType());
+        nonForResultIndices.push_back(yieldOperand.getOperandNumber());
         continue;
+      }
       OpResult forResult = cast<OpResult>(yieldOperand.get());
-      resultIdx.push_back(forResult.getResultNumber());
-      forResultToWarpResultMapping[forResult.getResultNumber()] =
-          yieldOperand.getOperandNumber();
+      forResultMapping[yieldOperand.getOperandNumber()] =
+          forResult.getResultNumber();
+      // forResultToWarpResultMapping[forResult.getResultNumber()] =
+      //     yieldOperand.getOperandNumber();
       // yieldOperand.set(forOp.getInitArgs()[forResult.getResultNumber()]);
     }
 
+    // llvm::errs() << "non for yielded values size: "
+    //              << nonForYieldedValues.size() << "\n";
+
     // llvm::errs() << "escpaing values size: " << escapingValues.size() <<
     // "\n";
-
     SmallVector<Value> yieldedValuesFromWarpOp;
     SmallVector<Type> yieldedTypesFromWarpOp;
     // All init args of the forOp are yielded from the original warp op.
@@ -1803,15 +1818,16 @@ struct WarpOpScfForOp : public WarpDistributionPattern {
       // find distributed type for the init arg.
       Type distType = initArg.getType();
       if (auto vecType = dyn_cast<VectorType>(distType)) {
-        if (forResultToWarpResultMapping.contains(i)) {
-          // If the init arg is yielded from the warp op, we need to compute the
-          // distributed type.
-          distType =
-              warpOp.getResult(forResultToWarpResultMapping[i]).getType();
-        } else {
-          AffineMap map = distributionMapFn(initArg);
-          distType = getDistributedType(vecType, map, warpOp.getWarpSize());
-        }
+        // if (forResultToWarpResultMapping.contains(i)) {
+        //   // If the init arg is yielded from the warp op, we need to compute
+        //   the
+        //   // distributed type.
+        //   distType =
+        //       warpOp.getResult(forResultToWarpResultMapping[i]).getType();
+        // } else {
+        AffineMap map = distributionMapFn(initArg);
+        distType = getDistributedType(vecType, map, warpOp.getWarpSize());
+        // }
       }
       // llvm::errs() << "distributed type: " << distType << "\n";
       yieldedTypesFromWarpOp.push_back(distType);
@@ -1823,12 +1839,23 @@ struct WarpOpScfForOp : public WarpDistributionPattern {
     yieldedTypesFromWarpOp.insert(yieldedTypesFromWarpOp.end(),
                                   distTypes.begin(), distTypes.end());
 
+    for (auto [i, v] : llvm::enumerate(nonForYieldedValues)) {
+      warpResultMapping[nonForResultIndices[i]] =
+          yieldedValuesFromWarpOp.size();
+      yieldedValuesFromWarpOp.push_back(v);
+      yieldedTypesFromWarpOp.push_back(
+          warpOp.getResult(nonForResultIndices[i]).getType());
+    }
+
     // SmallVector<size_t> newRetIndices;
     WarpExecuteOnLane0Op newWarpOp = moveRegionToNewWarpOpAndReplaceReturns(
         rewriter, warpOp, yieldedValuesFromWarpOp, yieldedTypesFromWarpOp);
     yield = cast<gpu::YieldOp>(
         newWarpOp.getBodyRegion().getBlocks().begin()->getTerminator());
 
+    // newWarpOp->print(llvm::outs());
+    // llvm::outs() << "\n";
+
     SmallVector<Value> newOperands;
     // Collect the new init args coming from the new warp op.
     for (size_t i = 0; i < forOp.getInitArgs().size(); ++i)
@@ -1857,12 +1884,13 @@ struct WarpOpScfForOp : public WarpDistributionPattern {
                                     forOp.getResultTypes().end());
     llvm::SmallDenseMap<Value, int64_t> argIndexMapping;
     // llvm::errs() << "setting arg index mapping\n";
-    for (size_t i = forOp.getInitArgs().size(); i < newWarpOp->getNumResults();
-         ++i) {
+    unsigned escapingValuesStartIdx = forOp.getInitArgs().size();
+    for (size_t i = escapingValuesStartIdx;
+         i < escapingValuesStartIdx + escapingValues.size(); ++i) {
       warpInput.push_back(newWarpOp.getResult(i));
-      argIndexMapping[escapingValues[i - forOp.getInitArgs().size()]] =
+      argIndexMapping[escapingValues[i - escapingValuesStartIdx]] =
           warpInputType.size();
-      warpInputType.push_back(inputTypes[i - forOp.getInitArgs().size()]);
+      warpInputType.push_back(inputTypes[i - escapingValuesStartIdx]);
     }
     // for (auto [i, r] : llvm::enumerate(
     //          newWarpOp.getResults().drop_front(forOp.getInitArgs().size())))
@@ -1907,9 +1935,16 @@ struct WarpOpScfForOp : public WarpDistributionPattern {
     // for (auto idx : resultIdx)
     //   llvm::errs() << idx << " ";
     // llvm::errs() << "\n";
-    for (const auto &res : llvm::enumerate(resultIdx)) {
-      rewriter.replaceAllUsesExcept(warpOp.getResult(res.value()),
-                                    newForOp.getResult(res.index()), newForOp);
+    for (auto [origIdx, newIdx] : forResultMapping) {
+      rewriter.replaceAllUsesExcept(warpOp.getResult(origIdx),
+                                    newForOp.getResult(newIdx), newForOp);
+      // newForOp->setOperand(res.index() + 3,
+      // newWarpOp.getResult(res.value()));
+    }
+
+    for (auto [origIdx, newIdx] : warpResultMapping) {
+      rewriter.replaceAllUsesWith(warpOp.getResult(origIdx),
+                                  newWarpOp.getResult(newIdx));
       // newForOp->setOperand(res.index() + 3,
       // newWarpOp.getResult(res.value()));
     }

>From 28ef9c9400695acb954c34bc48c9a43b710fb92b Mon Sep 17 00:00:00 2001
From: Charitha Saumya <charitha.saumya.gusthinna.waduge at intel.com>
Date: Tue, 8 Jul 2025 23:27:46 +0000
Subject: [PATCH 05/10] add comments and tests

---
 .../Vector/Transforms/VectorDistribute.cpp    | 217 ++++++++----------
 .../Vector/vector-warp-distribute.mlir        |  79 +++++++
 2 files changed, 172 insertions(+), 124 deletions(-)

diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorDistribute.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorDistribute.cpp
index adfed18a625b3..b49c2063b075d 100644
--- a/mlir/lib/Dialect/Vector/Transforms/VectorDistribute.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/VectorDistribute.cpp
@@ -1749,19 +1749,18 @@ struct WarpOpScfForOp : public WarpDistributionPattern {
       : WarpDistributionPattern(ctx, b), distributionMapFn(std::move(fn)) {}
   LogicalResult matchAndRewrite(WarpExecuteOnLane0Op warpOp,
                                 PatternRewriter &rewriter) const override {
-    auto yield = cast<gpu::YieldOp>(
+    auto newWarpOpYield = cast<gpu::YieldOp>(
         warpOp.getBodyRegion().getBlocks().begin()->getTerminator());
     // Only pick up forOp if it is the last op in the region.
-    Operation *lastNode = yield->getPrevNode();
+    Operation *lastNode = newWarpOpYield->getPrevNode();
     auto forOp = dyn_cast_or_null<scf::ForOp>(lastNode);
     if (!forOp)
       return failure();
     // Collect Values that come from the warp op but are outside the forOp.
-    // Those Value needs to be returned by the original warpOp and passed to
-    // the new op.
+    // Those Value needs to be returned by the new warp op.
     llvm::SmallSetVector<Value, 32> escapingValues;
-    SmallVector<Type> inputTypes;
-    SmallVector<Type> distTypes;
+    SmallVector<Type> escapingValueInputTypes;
+    SmallVector<Type> escapingValuedistTypes;
     mlir::visitUsedValuesDefinedAbove(
         forOp.getBodyRegion(), [&](OpOperand *operand) {
           Operation *parent = operand->get().getParentRegion()->getParentOp();
@@ -1773,183 +1772,155 @@ struct WarpOpScfForOp : public WarpDistributionPattern {
               AffineMap map = distributionMapFn(operand->get());
               distType = getDistributedType(vecType, map, warpOp.getWarpSize());
             }
-            inputTypes.push_back(operand->get().getType());
-            distTypes.push_back(distType);
+            escapingValueInputTypes.push_back(operand->get().getType());
+            escapingValuedistTypes.push_back(distType);
           }
         });
 
-    if (llvm::is_contained(distTypes, Type{}))
+    if (llvm::is_contained(escapingValuedistTypes, Type{}))
       return failure();
-
+    // Warp op can yield two types of values:
+    // 1. Values that are not results of the forOp:
+    //    These values must also be yielded by the new warp op. Also, we need to
+    //    record the index mapping for these values to replace them later.
+    // 2. Values that are results of the forOp:
+    //    In this case, we record the index mapping between the warp op result
+    //    index and matching forOp result index.
     SmallVector<Value> nonForYieldedValues;
-    // SmallVector<Type> nonForYieldedTypes;
     SmallVector<unsigned> nonForResultIndices;
-
-    // record result mapping.
     DenseMap<unsigned, unsigned> forResultMapping;
-    DenseMap<unsigned, unsigned> warpResultMapping;
-    // llvm::SmallDenseMap<unsigned, unsigned> forResultToWarpResultMapping;
-    for (OpOperand &yieldOperand : yield->getOpOperands()) {
+    for (OpOperand &yieldOperand : newWarpOpYield->getOpOperands()) {
+      // Yielded value is not a result of the forOp.
       if (yieldOperand.get().getDefiningOp() != forOp.getOperation()) {
         nonForYieldedValues.push_back(yieldOperand.get());
-        // nonForYieldedTypes.push_back(
-        //     warpOp.getResult(yieldOperand.getOperandNumber()).getType());
         nonForResultIndices.push_back(yieldOperand.getOperandNumber());
         continue;
       }
       OpResult forResult = cast<OpResult>(yieldOperand.get());
       forResultMapping[yieldOperand.getOperandNumber()] =
           forResult.getResultNumber();
-      // forResultToWarpResultMapping[forResult.getResultNumber()] =
-      //     yieldOperand.getOperandNumber();
-      // yieldOperand.set(forOp.getInitArgs()[forResult.getResultNumber()]);
     }
 
-    // llvm::errs() << "non for yielded values size: "
-    //              << nonForYieldedValues.size() << "\n";
-
-    // llvm::errs() << "escpaing values size: " << escapingValues.size() <<
-    // "\n";
-    SmallVector<Value> yieldedValuesFromWarpOp;
-    SmallVector<Type> yieldedTypesFromWarpOp;
-    // All init args of the forOp are yielded from the original warp op.
+    // Newly created warp op will yield values in following order:
+    // 1. All init args of the forOp.
+    // 2. All escaping values.
+    // 3. All non-for yielded values.
+    SmallVector<Value> newWarpOpYieldValues;
+    SmallVector<Type> newWarpOpDistTypes;
     for (auto [i, initArg] : llvm::enumerate(forOp.getInitArgs())) {
-      yieldedValuesFromWarpOp.push_back(initArg);
-      // find distributed type for the init arg.
+      newWarpOpYieldValues.push_back(initArg);
+      // Compute the distributed type for this init arg.
       Type distType = initArg.getType();
       if (auto vecType = dyn_cast<VectorType>(distType)) {
-        // if (forResultToWarpResultMapping.contains(i)) {
-        //   // If the init arg is yielded from the warp op, we need to compute
-        //   the
-        //   // distributed type.
-        //   distType =
-        //       warpOp.getResult(forResultToWarpResultMapping[i]).getType();
-        // } else {
         AffineMap map = distributionMapFn(initArg);
         distType = getDistributedType(vecType, map, warpOp.getWarpSize());
-        // }
       }
-      // llvm::errs() << "distributed type: " << distType << "\n";
-      yieldedTypesFromWarpOp.push_back(distType);
+      newWarpOpDistTypes.push_back(distType);
     }
-    // All escaping values are yielded from the original warp op.
-    yieldedValuesFromWarpOp.insert(yieldedValuesFromWarpOp.end(),
-                                   escapingValues.begin(),
-                                   escapingValues.end());
-    yieldedTypesFromWarpOp.insert(yieldedTypesFromWarpOp.end(),
-                                  distTypes.begin(), distTypes.end());
-
+    // Insert escaping values and their distributed types.
+    newWarpOpYieldValues.insert(newWarpOpYieldValues.end(),
+                                escapingValues.begin(), escapingValues.end());
+    newWarpOpDistTypes.insert(newWarpOpDistTypes.end(),
+                              escapingValuedistTypes.begin(),
+                              escapingValuedistTypes.end());
+    // Next, we insert all non-for yielded values and their distributed types.
+    // We also create a mapping between the non-for yielded value index and the
+    // corresponding new warp op yield value index (needed to update users
+    // later).
+    DenseMap<unsigned, unsigned> warpResultMapping;
     for (auto [i, v] : llvm::enumerate(nonForYieldedValues)) {
-      warpResultMapping[nonForResultIndices[i]] =
-          yieldedValuesFromWarpOp.size();
-      yieldedValuesFromWarpOp.push_back(v);
-      yieldedTypesFromWarpOp.push_back(
+      warpResultMapping[nonForResultIndices[i]] = newWarpOpYieldValues.size();
+      newWarpOpYieldValues.push_back(v);
+      newWarpOpDistTypes.push_back(
           warpOp.getResult(nonForResultIndices[i]).getType());
     }
-
-    // SmallVector<size_t> newRetIndices;
+    // Create the new warp op with the updated yield values and types.
     WarpExecuteOnLane0Op newWarpOp = moveRegionToNewWarpOpAndReplaceReturns(
-        rewriter, warpOp, yieldedValuesFromWarpOp, yieldedTypesFromWarpOp);
-    yield = cast<gpu::YieldOp>(
+        rewriter, warpOp, newWarpOpYieldValues, newWarpOpDistTypes);
+    newWarpOpYield = cast<gpu::YieldOp>(
         newWarpOp.getBodyRegion().getBlocks().begin()->getTerminator());
 
-    // newWarpOp->print(llvm::outs());
-    // llvm::outs() << "\n";
-
-    SmallVector<Value> newOperands;
-    // Collect the new init args coming from the new warp op.
-    for (size_t i = 0; i < forOp.getInitArgs().size(); ++i)
-      newOperands.push_back(newWarpOp.getResult(i));
-    // for (OpOperand &yieldOperand : yield->getOpOperands()) {
-    //   if (yieldOperand.get().getDefiningOp() != forOp.getOperation())
-    //     continue;
-    //   OpResult forResult = cast<OpResult>(yieldOperand.get());
-    //   resultIdx.push_back(forResult.getResultNumber());
-    //   yieldOperand.set(forOp.getInitArgs()[forResult.getResultNumber()]);
-    // }
+    // Next, we create a new for op with the init args yielded by the new
+    // warp op.
+    unsigned escapingValuesStartIdx =
+        forOp.getInitArgs().size(); // ForOp init args are positioned before
+                                    // escaping values in the new warp op.
+    SmallVector<Value> newForOpOperands;
+    for (size_t i = 0; i < escapingValuesStartIdx; ++i)
+      newForOpOperands.push_back(newWarpOp.getResult(i));
 
+    // Create a new for op outside the new warp op region.
     OpBuilder::InsertionGuard g(rewriter);
     rewriter.setInsertionPointAfter(newWarpOp);
-
-    // Create a new for op outside the region with a WarpExecuteOnLane0Op
-    // region inside.
     auto newForOp = rewriter.create<scf::ForOp>(
         forOp.getLoc(), forOp.getLowerBound(), forOp.getUpperBound(),
-        forOp.getStep(), newOperands);
+        forOp.getStep(), newForOpOperands);
+    // Next, we insert a new warp op (called inner warp op) inside the
+    // newly created for op. This warp op will contain all ops that were
+    // contained within the original for op body.
     rewriter.setInsertionPointToStart(newForOp.getBody());
 
-    SmallVector<Value> warpInput(newForOp.getRegionIterArgs().begin(),
-                                 newForOp.getRegionIterArgs().end());
-    SmallVector<Type> warpInputType(forOp.getResultTypes().begin(),
-                                    forOp.getResultTypes().end());
+    SmallVector<Value> innerWarpInput(newForOp.getRegionIterArgs().begin(),
+                                      newForOp.getRegionIterArgs().end());
+    SmallVector<Type> innerWarpInputType(forOp.getResultTypes().begin(),
+                                         forOp.getResultTypes().end());
+    // Escaping values are forwarded to the inner warp op as its (additional)
+    // arguments. We keep track of the mapping between these values and their
+    // argument index in the inner warp op (to replcace uses later).
     llvm::SmallDenseMap<Value, int64_t> argIndexMapping;
-    // llvm::errs() << "setting arg index mapping\n";
-    unsigned escapingValuesStartIdx = forOp.getInitArgs().size();
     for (size_t i = escapingValuesStartIdx;
          i < escapingValuesStartIdx + escapingValues.size(); ++i) {
-      warpInput.push_back(newWarpOp.getResult(i));
+      innerWarpInput.push_back(newWarpOp.getResult(i));
       argIndexMapping[escapingValues[i - escapingValuesStartIdx]] =
-          warpInputType.size();
-      warpInputType.push_back(inputTypes[i - escapingValuesStartIdx]);
+          innerWarpInputType.size();
+      innerWarpInputType.push_back(
+          escapingValueInputTypes[i - escapingValuesStartIdx]);
     }
-    // for (auto [i, r] : llvm::enumerate(
-    //          newWarpOp.getResults().drop_front(forOp.getInitArgs().size())))
-    //          {
-    //   warpInput.push_back(r);
-    //   argIndexMapping[escapingValues[i]] = warpInputType.size();
-    //   warpInputType.push_back(inputTypes[i]);
-    // }
-    // llvm::errs() << "go here\n";
+    // Create the inner warp op with the new input values and types.
     auto innerWarp = rewriter.create<WarpExecuteOnLane0Op>(
         newWarpOp.getLoc(), newForOp.getResultTypes(), newWarpOp.getLaneid(),
-        newWarpOp.getWarpSize(), warpInput, warpInputType);
-    // newForOp->getParentOp()->print(llvm::outs());
-    // llvm::outs() << "\n";
+        newWarpOp.getWarpSize(), innerWarpInput, innerWarpInputType);
 
+    // Inline the for op body into the inner warp op body.
     SmallVector<Value> argMapping;
     argMapping.push_back(newForOp.getInductionVar());
-    for (Value args : innerWarp.getBody()->getArguments()) {
+    for (Value args : innerWarp.getBody()->getArguments())
       argMapping.push_back(args);
-    }
-    auto forOpCopy = cast<scf::ForOp>(rewriter.clone(*forOp.getOperation()));
-    argMapping.resize(forOpCopy.getBody()->getNumArguments());
+
+    argMapping.resize(forOp.getBody()->getNumArguments());
     SmallVector<Value> yieldOperands;
-    for (Value operand : forOpCopy.getBody()->getTerminator()->getOperands())
+    for (Value operand : forOp.getBody()->getTerminator()->getOperands())
       yieldOperands.push_back(operand);
 
-    rewriter.eraseOp(forOpCopy.getBody()->getTerminator());
-    rewriter.mergeBlocks(forOpCopy.getBody(), innerWarp.getBody(), argMapping);
+    rewriter.eraseOp(forOp.getBody()->getTerminator());
+    rewriter.mergeBlocks(forOp.getBody(), innerWarp.getBody(), argMapping);
+
+    // Insert a gpu yieldOp at the end of the inner warp op body that yields
+    // original forOp results.
     rewriter.setInsertionPointToEnd(innerWarp.getBody());
     rewriter.create<gpu::YieldOp>(innerWarp.getLoc(), yieldOperands);
     rewriter.setInsertionPointAfter(innerWarp);
+    // Insert a scf.yield op at the end of the new for op body that yields
+    // the inner warp op results.
     if (!innerWarp.getResults().empty())
-      rewriter.create<scf::YieldOp>(forOpCopy.getLoc(), innerWarp.getResults());
-    // forOpCopy->getParentOp()->getParentOp()->print(llvm::outs());
-    // llvm::outs() << "\n";
-    // llvm::errs() << "erasing for op\n";
-
-    rewriter.eraseOp(forOpCopy);
-    // Replace the warpOp result coming from the original ForOp.
-    // print resultIdx for debugging.
-    // llvm::errs() << "resultIdx: ";
-    // for (auto idx : resultIdx)
-    //   llvm::errs() << idx << " ";
-    // llvm::errs() << "\n";
-    for (auto [origIdx, newIdx] : forResultMapping) {
+      rewriter.create<scf::YieldOp>(forOp.getLoc(), innerWarp.getResults());
+
+    // Update the users of original warp op results that were coming from the
+    // original forOp to the corresponding new forOp result.
+    for (auto [origIdx, newIdx] : forResultMapping)
       rewriter.replaceAllUsesExcept(warpOp.getResult(origIdx),
                                     newForOp.getResult(newIdx), newForOp);
-      // newForOp->setOperand(res.index() + 3,
-      // newWarpOp.getResult(res.value()));
-    }
-
-    for (auto [origIdx, newIdx] : warpResultMapping) {
+    // Similarly, update any users of the warp op results that were not
+    // results of the forOp.
+    for (auto [origIdx, newIdx] : warpResultMapping)
       rewriter.replaceAllUsesWith(warpOp.getResult(origIdx),
                                   newWarpOp.getResult(newIdx));
-      // newForOp->setOperand(res.index() + 3,
-      // newWarpOp.getResult(res.value()));
-    }
+    // Remove the original warp op and for op, they should not have any uses
+    // at this point.
     rewriter.eraseOp(forOp);
     rewriter.eraseOp(warpOp);
+    // Update any users of escaping values that were forwarded to the
+    // inner warp op. These values are now arguments of the inner warp op.
     newForOp.walk([&](Operation *op) {
       for (OpOperand &operand : op->getOpOperands()) {
         auto it = argIndexMapping.find(operand.get());
@@ -1958,8 +1929,6 @@ struct WarpOpScfForOp : public WarpDistributionPattern {
         operand.set(innerWarp.getBodyRegion().getArgument(it->second));
       }
     });
-    // newForOp->getParentOp()->print(llvm::outs());
-    // llvm::outs() << "\n";
 
     // Finally, hoist out any now uniform code from the inner warp op.
     mlir::vector::moveScalarUniformCode(innerWarp);
diff --git a/mlir/test/Dialect/Vector/vector-warp-distribute.mlir b/mlir/test/Dialect/Vector/vector-warp-distribute.mlir
index 7cfbcdf101d11..3982783c764df 100644
--- a/mlir/test/Dialect/Vector/vector-warp-distribute.mlir
+++ b/mlir/test/Dialect/Vector/vector-warp-distribute.mlir
@@ -584,6 +584,85 @@ func.func @warp_scf_for_multiple_yield(%arg0: index, %arg1: memref<?xf32>, %arg2
   return
 }
 
+// -----
+// CHECK-PROP-LABEL: func.func @warp_scf_for_unused_for_result(
+//       CHECK-PROP: %[[W0:.*]]:2 = gpu.warp_execute_on_lane_0(%{{.*}})[32] -> (vector<4xf32>, vector<4xf32>) {
+//       CHECK-PROP:  %[[INI0:.*]] = "some_def"() : () -> vector<128xf32>
+//       CHECK-PROP:  %[[INI1:.*]] = "some_def"() : () -> vector<128xf32>
+//       CHECK-PROP:  gpu.yield %[[INI0]], %[[INI1]] : vector<128xf32>, vector<128xf32>
+//       CHECK-PROP: }
+//       CHECK-PROP: %[[F:.*]]:2 = scf.for %{{.*}} iter_args(%{{.*}} = %[[W0]]#0, %{{.*}} = %[[W0]]#1) -> (vector<4xf32>, vector<4xf32>) {
+//       CHECK-PROP:  %[[W1:.*]]:2 = gpu.warp_execute_on_lane_0(%{{.*}})[32] args(%{{.*}} : vector<4xf32>, vector<4xf32>) -> (vector<4xf32>, vector<4xf32>) {
+//       CHECK-PROP:    %[[ACC0:.*]] = "some_def"(%{{.*}}) : (vector<128xf32>, index) -> vector<128xf32>
+//       CHECK-PROP:    %[[ACC1:.*]] = "some_def"(%{{.*}}) : (index, vector<128xf32>, vector<128xf32>) -> vector<128xf32>
+//       CHECK-PROP:    gpu.yield %[[ACC1]], %[[ACC0]] : vector<128xf32>, vector<128xf32>
+//       CHECK-PROP:  }
+//       CHECK-PROP:  scf.yield %[[W1]]#0, %[[W1]]#1 : vector<4xf32>, vector<4xf32>
+//       CHECK-PROP: }
+//       CHECK-PROP: "some_use"(%[[F]]#0) : (vector<4xf32>) -> ()
+func.func @warp_scf_for_unused_for_result(%arg0: index) {
+  %c128 = arith.constant 128 : index
+  %c1 = arith.constant 1 : index
+  %c0 = arith.constant 0 : index
+  %0 = gpu.warp_execute_on_lane_0(%arg0)[32] -> (vector<4xf32>) {
+    %ini = "some_def"() : () -> (vector<128xf32>)
+    %ini1 = "some_def"() : () -> (vector<128xf32>)
+    %3:2 = scf.for %arg3 = %c0 to %c128 step %c1 iter_args(%arg4 = %ini, %arg5 = %ini1) -> (vector<128xf32>, vector<128xf32>) {
+      %add = arith.addi %arg3, %c1 : index
+      %1  = "some_def"(%arg5, %add) : (vector<128xf32>, index) -> (vector<128xf32>)
+      %acc = "some_def"(%add, %arg4, %1) : (index, vector<128xf32>, vector<128xf32>) -> (vector<128xf32>)
+      scf.yield %acc, %1 : vector<128xf32>, vector<128xf32>
+    }
+    gpu.yield %3#0 : vector<128xf32>
+  }
+  "some_use"(%0) : (vector<4xf32>) -> ()
+  return
+}
+
+// -----
+// CHECK-PROP-LABEL: func.func @warp_scf_for_swapped_for_results(
+//       CHECK-PROP:  %[[W0:.*]]:3 = gpu.warp_execute_on_lane_0(%{{.*}})[32] -> (vector<8xf32>, vector<4xf32>, vector<4xf32>) {
+//  CHECK-PROP-NEXT:    %[[INI0:.*]] = "some_def"() : () -> vector<256xf32>
+//  CHECK-PROP-NEXT:    %[[INI1:.*]] = "some_def"() : () -> vector<128xf32>
+//  CHECK-PROP-NEXT:    %[[INI2:.*]] = "some_def"() : () -> vector<128xf32>
+//  CHECK-PROP-NEXT:    gpu.yield %[[INI0]], %[[INI1]], %[[INI2]] : vector<256xf32>, vector<128xf32>, vector<128xf32>
+//  CHECK-PROP-NEXT:  }
+//  CHECK-PROP-NEXT:  %[[F0:.*]]:3 = scf.for {{.*}} iter_args(%{{.*}} = %[[W0]]#0, %{{.*}} = %[[W0]]#1, %{{.*}} = %[[W0]]#2) -> (vector<8xf32>, vector<4xf32>, vector<4xf32>) {
+//  CHECK-PROP-NEXT:    %[[W1:.*]]:3 = gpu.warp_execute_on_lane_0(%{{.*}})[32] args(%{{.*}} :
+//  CHECK-PROP-SAME:        vector<8xf32>, vector<4xf32>, vector<4xf32>) -> (vector<8xf32>, vector<4xf32>, vector<4xf32>) {
+//  CHECK-PROP-NEXT:      ^bb0(%{{.*}}: vector<256xf32>, %{{.*}}: vector<128xf32>, %{{.*}}: vector<128xf32>):
+//  CHECK-PROP-NEXT:        %[[T3:.*]] = "some_def_1"(%{{.*}}) : (vector<256xf32>) -> vector<256xf32>
+//  CHECK-PROP-NEXT:        %[[T4:.*]] = "some_def_2"(%{{.*}}) : (vector<128xf32>) -> vector<128xf32>
+//  CHECK-PROP-NEXT:        %[[T5:.*]] = "some_def_3"(%{{.*}}) : (vector<128xf32>) -> vector<128xf32>
+//  CHECK-PROP-NEXT:        gpu.yield %[[T3]], %[[T4]], %[[T5]] : vector<256xf32>, vector<128xf32>, vector<128xf32>
+//  CHECK-PROP-NEXT:    }
+//  CHECK-PROP-NEXT:    scf.yield %[[W1]]#0, %[[W1]]#1, %[[W1]]#2 : vector<8xf32>, vector<4xf32>, vector<4xf32>
+//  CHECK-PROP-NEXT:  }
+//  CHECK-PROP-NEXT:  "some_use_1"(%[[F0]]#2) : (vector<4xf32>) -> ()
+//  CHECK-PROP-NEXT:  "some_use_2"(%[[F0]]#1) : (vector<4xf32>) -> ()
+//  CHECK-PROP-NEXT:  "some_use_3"(%[[F0]]#0) : (vector<8xf32>) -> ()
+func.func @warp_scf_for_swapped_for_results(%arg0: index) {
+  %c128 = arith.constant 128 : index
+  %c1 = arith.constant 1 : index
+  %c0 = arith.constant 0 : index
+  %0:3 = gpu.warp_execute_on_lane_0(%arg0)[32] -> (vector<4xf32>, vector<4xf32>, vector<8xf32>) {
+    %ini1 = "some_def"() : () -> (vector<256xf32>)
+    %ini2 = "some_def"() : () -> (vector<128xf32>)
+    %ini3 = "some_def"() : () -> (vector<128xf32>)
+    %3:3 = scf.for %arg3 = %c0 to %c128 step %c1 iter_args(%arg4 = %ini1, %arg5 = %ini2, %arg6 = %ini3) -> (vector<256xf32>, vector<128xf32>, vector<128xf32>) {
+      %acc1 = "some_def_1"(%arg4) : (vector<256xf32>) -> (vector<256xf32>)
+      %acc2 = "some_def_2"(%arg5) : (vector<128xf32>) -> (vector<128xf32>)
+      %acc3 = "some_def_3"(%arg6) : (vector<128xf32>) -> (vector<128xf32>)
+      scf.yield %acc1, %acc2, %acc3 : vector<256xf32>, vector<128xf32>, vector<128xf32>
+    }
+    gpu.yield %3#2, %3#1, %3#0 : vector<128xf32>, vector<128xf32>, vector<256xf32>
+  }
+  "some_use_1"(%0#0) : (vector<4xf32>) -> ()
+  "some_use_2"(%0#1) : (vector<4xf32>) -> ()
+  "some_use_3"(%0#2) : (vector<8xf32>) -> ()
+  return
+}
+
 // -----
 
 // CHECK-PROP-LABEL: func @vector_reduction(

>From 537ca0e285c9a220a0dc0d53e24f51c86e81d5e7 Mon Sep 17 00:00:00 2001
From: Charitha Saumya <charitha.saumya.gusthinna.waduge at intel.com>
Date: Wed, 9 Jul 2025 02:46:33 +0000
Subject: [PATCH 06/10] add missing logic

---
 .../Transforms/XeGPUSubgroupDistribute.cpp    | 24 ++++++++++++++++---
 1 file changed, 21 insertions(+), 3 deletions(-)

diff --git a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUSubgroupDistribute.cpp b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUSubgroupDistribute.cpp
index c072557c2bd22..ef257307de569 100644
--- a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUSubgroupDistribute.cpp
+++ b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUSubgroupDistribute.cpp
@@ -34,6 +34,7 @@
 #include "llvm/ADT/ArrayRef.h"
 #include "llvm/ADT/STLExtras.h"
 #include "llvm/ADT/SmallVector.h"
+#include "llvm/ADT/SmallVectorExtras.h"
 
 namespace mlir {
 namespace xegpu {
@@ -876,15 +877,32 @@ void XeGPUSubgroupDistributePass::runOnOperation() {
   // Step 3: Apply subgroup to workitem distribution patterns.
   RewritePatternSet patterns(&getContext());
   xegpu::populateXeGPUSubgroupDistributePatterns(patterns);
-  // TODO: distributionFn and shuffleFn are not used at this point.
+  // distributionFn is used by vector distribution patterns to determine the
+  // distributed vector type for a given vector value. In XeGPU subgroup
+  // distribution context, we compute this based on lane layout.
   auto distributionFn = [](Value val) {
     VectorType vecType = dyn_cast<VectorType>(val.getType());
     int64_t vecRank = vecType ? vecType.getRank() : 0;
-    OpBuilder builder(val.getContext());
     if (vecRank == 0)
       return AffineMap::get(val.getContext());
-    return AffineMap::getMultiDimIdentityMap(vecRank, val.getContext());
+    // Get the layout of the vector type.
+    xegpu::LayoutAttr layout = xegpu::getLayoutAttr(val);
+    // If no layout is specified, assume the innermost dimension is distributed
+    // for now.
+    if (!layout)
+      return AffineMap::getMultiDimMapWithTargets(
+          vecRank, {static_cast<unsigned int>(vecRank - 1)}, val.getContext());
+    SmallVector<unsigned int> distributedDims;
+    // Get the distributed dimensions based on the layout.
+    ArrayRef<int> laneLayout = layout.getLaneLayout().asArrayRef();
+    for (unsigned i = 0; i < laneLayout.size(); ++i) {
+      if (laneLayout[i] > 1)
+        distributedDims.push_back(i);
+    }
+    return AffineMap::getMultiDimMapWithTargets(vecRank, distributedDims,
+                                                val.getContext());
   };
+  // TODO: shuffleFn is not used.
   auto shuffleFn = [](Location loc, OpBuilder &builder, Value val, Value srcIdx,
                       int64_t warpSz) { return Value(); };
   vector::populatePropagateWarpVectorDistributionPatterns(

>From 164e9d619880ae24388234edcd9297b66c31d209 Mon Sep 17 00:00:00 2001
From: Charitha Saumya <charitha.saumya.gusthinna.waduge at intel.com>
Date: Wed, 9 Jul 2025 17:57:13 +0000
Subject: [PATCH 07/10] address comments

---
 .../Vector/Transforms/VectorDistribute.cpp    | 88 +++++++++----------
 1 file changed, 44 insertions(+), 44 deletions(-)

diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorDistribute.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorDistribute.cpp
index 3ce134fe5f3ce..7d3d6b98666a1 100644
--- a/mlir/lib/Dialect/Vector/Transforms/VectorDistribute.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/VectorDistribute.cpp
@@ -1751,13 +1751,13 @@ struct WarpOpScfForOp : public WarpDistributionPattern {
                                 PatternRewriter &rewriter) const override {
     auto newWarpOpYield = cast<gpu::YieldOp>(
         warpOp.getBodyRegion().getBlocks().begin()->getTerminator());
-    // Only pick up forOp if it is the last op in the region.
+    // Only pick up `ForOp` if it is the last op in the region.
     Operation *lastNode = newWarpOpYield->getPrevNode();
     auto forOp = dyn_cast_or_null<scf::ForOp>(lastNode);
     if (!forOp)
       return failure();
-    // Collect Values that come from the warp op but are outside the forOp.
-    // Those Value needs to be returned by the new warp op.
+    // Collect Values that come from the `WarpOp` but are outside the `ForOp`.
+    // Those Values need to be returned by the new `WarpOp`.
     llvm::SmallSetVector<Value, 32> escapingValues;
     SmallVector<Type> escapingValueInputTypes;
     SmallVector<Type> escapingValuedistTypes;
@@ -1779,16 +1779,16 @@ struct WarpOpScfForOp : public WarpDistributionPattern {
 
     if (llvm::is_contained(escapingValuedistTypes, Type{}))
       return failure();
-    // Warp op can yield two types of values:
-    // 1. Values that are not results of the forOp:
-    //    These values must also be yielded by the new warp op. Also, we need to
-    //    record the index mapping for these values to replace them later.
-    // 2. Values that are results of the forOp:
-    //    In this case, we record the index mapping between the warp op result
-    //    index and matching forOp result index.
+    // `WarpOp` can yield two types of values:
+    // 1. Values that are not results of the `ForOp`:
+    //    These values must also be yielded by the new `WarpOp`. Also, we need
+    //    to record the index mapping for these values to replace them later.
+    // 2. Values that are results of the `ForOp`:
+    //    In this case, we record the index mapping between the `WarpOp` result
+    //    index and matching `ForOp` result index.
     SmallVector<Value> nonForYieldedValues;
     SmallVector<unsigned> nonForResultIndices;
-    DenseMap<unsigned, unsigned> forResultMapping;
+    llvm::SmallDenseMap<unsigned, unsigned> forResultMapping;
     for (OpOperand &yieldOperand : newWarpOpYield->getOpOperands()) {
       // Yielded value is not a result of the forOp.
       if (yieldOperand.get().getDefiningOp() != forOp.getOperation()) {
@@ -1801,10 +1801,10 @@ struct WarpOpScfForOp : public WarpDistributionPattern {
           forResult.getResultNumber();
     }
 
-    // Newly created warp op will yield values in following order:
-    // 1. All init args of the forOp.
+    // The newly created `WarpOp` will yield values in the following order:
+    // 1. All init args of the `ForOp`.
     // 2. All escaping values.
-    // 3. All non-for yielded values.
+    // 3. All non-`ForOp` yielded values.
     SmallVector<Value> newWarpOpYieldValues;
     SmallVector<Type> newWarpOpDistTypes;
     for (auto [i, initArg] : llvm::enumerate(forOp.getInitArgs())) {
@@ -1823,50 +1823,50 @@ struct WarpOpScfForOp : public WarpDistributionPattern {
     newWarpOpDistTypes.insert(newWarpOpDistTypes.end(),
                               escapingValuedistTypes.begin(),
                               escapingValuedistTypes.end());
-    // Next, we insert all non-for yielded values and their distributed types.
-    // We also create a mapping between the non-for yielded value index and the
-    // corresponding new warp op yield value index (needed to update users
-    // later).
-    DenseMap<unsigned, unsigned> warpResultMapping;
+    // Next, we insert all non-`ForOp` yielded values and their distributed
+    // types. We also create a mapping between the non-`ForOp` yielded value
+    // index and the corresponding new `WarpOp` yield value index (needed to
+    // update users later).
+    llvm::SmallDenseMap<unsigned, unsigned> warpResultMapping;
     for (auto [i, v] : llvm::enumerate(nonForYieldedValues)) {
       warpResultMapping[nonForResultIndices[i]] = newWarpOpYieldValues.size();
       newWarpOpYieldValues.push_back(v);
       newWarpOpDistTypes.push_back(
           warpOp.getResult(nonForResultIndices[i]).getType());
     }
-    // Create the new warp op with the updated yield values and types.
+    // Create the new `WarpOp` with the updated yield values and types.
     WarpExecuteOnLane0Op newWarpOp = moveRegionToNewWarpOpAndReplaceReturns(
         rewriter, warpOp, newWarpOpYieldValues, newWarpOpDistTypes);
     newWarpOpYield = cast<gpu::YieldOp>(
         newWarpOp.getBodyRegion().getBlocks().begin()->getTerminator());
 
-    // Next, we create a new for op with the init args yielded by the new
-    // warp op.
-    unsigned escapingValuesStartIdx =
-        forOp.getInitArgs().size(); // ForOp init args are positioned before
-                                    // escaping values in the new warp op.
+    // Next, we create a new `ForOp` with the init args yielded by the new
+    // `WarpOp`.
+    const unsigned escapingValuesStartIdx =
+        forOp.getInitArgs().size(); // `ForOp` init args are positioned before
+                                    // escaping values in the new `WarpOp`.
     SmallVector<Value> newForOpOperands;
     for (size_t i = 0; i < escapingValuesStartIdx; ++i)
       newForOpOperands.push_back(newWarpOp.getResult(i));
 
-    // Create a new for op outside the new warp op region.
+    // Create a new `ForOp` outside the new `WarpOp` region.
     OpBuilder::InsertionGuard g(rewriter);
     rewriter.setInsertionPointAfter(newWarpOp);
     auto newForOp = rewriter.create<scf::ForOp>(
         forOp.getLoc(), forOp.getLowerBound(), forOp.getUpperBound(),
         forOp.getStep(), newForOpOperands);
-    // Next, we insert a new warp op (called inner warp op) inside the
-    // newly created for op. This warp op will contain all ops that were
-    // contained within the original for op body.
+    // Next, we insert a new `WarpOp` (called inner `WarpOp`) inside the
+    // newly created `ForOp`. This `WarpOp` will contain all ops that were
+    // contained within the original `ForOp` body.
     rewriter.setInsertionPointToStart(newForOp.getBody());
 
     SmallVector<Value> innerWarpInput(newForOp.getRegionIterArgs().begin(),
                                       newForOp.getRegionIterArgs().end());
     SmallVector<Type> innerWarpInputType(forOp.getResultTypes().begin(),
                                          forOp.getResultTypes().end());
-    // Escaping values are forwarded to the inner warp op as its (additional)
+    // Escaping values are forwarded to the inner `WarpOp` as its (additional)
     // arguments. We keep track of the mapping between these values and their
-    // argument index in the inner warp op (to replcace uses later).
+    // argument index in the inner `WarpOp` (to replace users later).
     llvm::SmallDenseMap<Value, int64_t> argIndexMapping;
     for (size_t i = escapingValuesStartIdx;
          i < escapingValuesStartIdx + escapingValues.size(); ++i) {
@@ -1876,12 +1876,12 @@ struct WarpOpScfForOp : public WarpDistributionPattern {
       innerWarpInputType.push_back(
           escapingValueInputTypes[i - escapingValuesStartIdx]);
     }
-    // Create the inner warp op with the new input values and types.
+    // Create the inner `WarpOp` with the new input values and types.
     auto innerWarp = rewriter.create<WarpExecuteOnLane0Op>(
         newWarpOp.getLoc(), newForOp.getResultTypes(), newWarpOp.getLaneid(),
         newWarpOp.getWarpSize(), innerWarpInput, innerWarpInputType);
 
-    // Inline the for op body into the inner warp op body.
+    // Inline the `ForOp` body into the inner `WarpOp` body.
     SmallVector<Value> argMapping;
     argMapping.push_back(newForOp.getInductionVar());
     for (Value args : innerWarp.getBody()->getArguments())
@@ -1895,32 +1895,32 @@ struct WarpOpScfForOp : public WarpDistributionPattern {
     rewriter.eraseOp(forOp.getBody()->getTerminator());
     rewriter.mergeBlocks(forOp.getBody(), innerWarp.getBody(), argMapping);
 
-    // Insert a gpu yieldOp at the end of the inner warp op body that yields
-    // original forOp results.
+    // Insert a gpu `YieldOp` at the end of the inner `WarpOp` body that yields
+    // original `ForOp` results.
     rewriter.setInsertionPointToEnd(innerWarp.getBody());
     rewriter.create<gpu::YieldOp>(innerWarp.getLoc(), yieldOperands);
     rewriter.setInsertionPointAfter(innerWarp);
-    // Insert a scf.yield op at the end of the new for op body that yields
-    // the inner warp op results.
+    // Insert a scf.yield op at the end of the new `ForOp` body that yields
+    // the inner `WarpOp` results.
     if (!innerWarp.getResults().empty())
       rewriter.create<scf::YieldOp>(forOp.getLoc(), innerWarp.getResults());
 
-    // Update the users of original warp op results that were coming from the
-    // original forOp to the corresponding new forOp result.
+    // Update the users of original `WarpOp` results that were coming from the
+    // original `ForOp` to the corresponding new `ForOp` result.
     for (auto [origIdx, newIdx] : forResultMapping)
       rewriter.replaceAllUsesExcept(warpOp.getResult(origIdx),
                                     newForOp.getResult(newIdx), newForOp);
-    // Similarly, update any users of the warp op results that were not
-    // results of the forOp.
+    // Similarly, update any users of the `WarpOp` results that were not
+    // results of the `ForOp`.
     for (auto [origIdx, newIdx] : warpResultMapping)
       rewriter.replaceAllUsesWith(warpOp.getResult(origIdx),
                                   newWarpOp.getResult(newIdx));
-    // Remove the original warp op and for op, they should not have any uses
+    // Remove the original `WarpOp` and `ForOp`, they should not have any uses
     // at this point.
     rewriter.eraseOp(forOp);
     rewriter.eraseOp(warpOp);
     // Update any users of escaping values that were forwarded to the
-    // inner warp op. These values are now arguments of the inner warp op.
+    // inner `WarpOp`. These values are now arguments of the inner `WarpOp`.
     newForOp.walk([&](Operation *op) {
       for (OpOperand &operand : op->getOpOperands()) {
         auto it = argIndexMapping.find(operand.get());
@@ -1930,7 +1930,7 @@ struct WarpOpScfForOp : public WarpDistributionPattern {
       }
     });
 
-    // Finally, hoist out any now uniform code from the inner warp op.
+    // Finally, hoist out any now uniform code from the inner `WarpOp`.
     mlir::vector::moveScalarUniformCode(innerWarp);
     return success();
   }

>From 683fad8092e6c40a1bc95bf2f249653236f72960 Mon Sep 17 00:00:00 2001
From: Charitha Saumya <charitha.saumya.gusthinna.waduge at intel.com>
Date: Thu, 10 Jul 2025 22:18:19 +0000
Subject: [PATCH 08/10] address comments

---
 .../lib/Dialect/Vector/Transforms/VectorDistribute.cpp | 10 +++++-----
 1 file changed, 5 insertions(+), 5 deletions(-)

diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorDistribute.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorDistribute.cpp
index 7d3d6b98666a1..f4928ee6c4221 100644
--- a/mlir/lib/Dialect/Vector/Transforms/VectorDistribute.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/VectorDistribute.cpp
@@ -1760,7 +1760,7 @@ struct WarpOpScfForOp : public WarpDistributionPattern {
     // Those Values need to be returned by the new warp op.
     llvm::SmallSetVector<Value, 32> escapingValues;
     SmallVector<Type> escapingValueInputTypes;
-    SmallVector<Type> escapingValuedistTypes;
+    SmallVector<Type> escapingValueDistTypes;
     mlir::visitUsedValuesDefinedAbove(
         forOp.getBodyRegion(), [&](OpOperand *operand) {
           Operation *parent = operand->get().getParentRegion()->getParentOp();
@@ -1773,11 +1773,11 @@ struct WarpOpScfForOp : public WarpDistributionPattern {
               distType = getDistributedType(vecType, map, warpOp.getWarpSize());
             }
             escapingValueInputTypes.push_back(operand->get().getType());
-            escapingValuedistTypes.push_back(distType);
+            escapingValueDistTypes.push_back(distType);
           }
         });
 
-    if (llvm::is_contained(escapingValuedistTypes, Type{}))
+    if (llvm::is_contained(escapingValueDistTypes, Type{}))
       return failure();
     // `WarpOp` can yield two types of values:
     // 1. Values that are not results of the `ForOp`:
@@ -1821,8 +1821,8 @@ struct WarpOpScfForOp : public WarpDistributionPattern {
     newWarpOpYieldValues.insert(newWarpOpYieldValues.end(),
                                 escapingValues.begin(), escapingValues.end());
     newWarpOpDistTypes.insert(newWarpOpDistTypes.end(),
-                              escapingValuedistTypes.begin(),
-                              escapingValuedistTypes.end());
+                              escapingValueDistTypes.begin(),
+                              escapingValueDistTypes.end());
     // Next, we insert all non-`ForOp` yielded values and their distributed
     // types. We also create a mapping between the non-`ForOp` yielded value
     // index and the corresponding new `WarpOp` yield value index (needed to

>From 2c2703eb9511ff8b2df44de23ffa0efb15308772 Mon Sep 17 00:00:00 2001
From: Charitha Saumya <charitha.saumya.gusthinna.waduge at intel.com>
Date: Fri, 11 Jul 2025 16:26:37 +0000
Subject: [PATCH 09/10] address comments

---
 .../Dialect/Vector/Transforms/VectorDistribute.cpp | 14 ++++++--------
 1 file changed, 6 insertions(+), 8 deletions(-)

diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorDistribute.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorDistribute.cpp
index f4928ee6c4221..af7c34f354668 100644
--- a/mlir/lib/Dialect/Vector/Transforms/VectorDistribute.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/VectorDistribute.cpp
@@ -1749,10 +1749,10 @@ struct WarpOpScfForOp : public WarpDistributionPattern {
       : WarpDistributionPattern(ctx, b), distributionMapFn(std::move(fn)) {}
   LogicalResult matchAndRewrite(WarpExecuteOnLane0Op warpOp,
                                 PatternRewriter &rewriter) const override {
-    auto newWarpOpYield = cast<gpu::YieldOp>(
+    auto warpOpYield = cast<gpu::YieldOp>(
         warpOp.getBodyRegion().getBlocks().begin()->getTerminator());
     // Only pick up `ForOp` if it is the last op in the region.
-    Operation *lastNode = newWarpOpYield->getPrevNode();
+    Operation *lastNode = warpOpYield->getPrevNode();
     auto forOp = dyn_cast_or_null<scf::ForOp>(lastNode);
     if (!forOp)
       return failure();
@@ -1789,7 +1789,7 @@ struct WarpOpScfForOp : public WarpDistributionPattern {
     SmallVector<Value> nonForYieldedValues;
     SmallVector<unsigned> nonForResultIndices;
     llvm::SmallDenseMap<unsigned, unsigned> forResultMapping;
-    for (OpOperand &yieldOperand : newWarpOpYield->getOpOperands()) {
+    for (OpOperand &yieldOperand : warpOpYield->getOpOperands()) {
       // Yielded value is not a result of the forOp.
       if (yieldOperand.get().getDefiningOp() != forOp.getOperation()) {
         nonForYieldedValues.push_back(yieldOperand.get());
@@ -1827,9 +1827,9 @@ struct WarpOpScfForOp : public WarpDistributionPattern {
     // types. We also create a mapping between the non-`ForOp` yielded value
     // index and the corresponding new `WarpOp` yield value index (needed to
     // update users later).
-    llvm::SmallDenseMap<unsigned, unsigned> warpResultMapping;
+    llvm::SmallDenseMap<unsigned, unsigned> nonForResultMapping;
     for (auto [i, v] : llvm::enumerate(nonForYieldedValues)) {
-      warpResultMapping[nonForResultIndices[i]] = newWarpOpYieldValues.size();
+      nonForResultMapping[nonForResultIndices[i]] = newWarpOpYieldValues.size();
       newWarpOpYieldValues.push_back(v);
       newWarpOpDistTypes.push_back(
           warpOp.getResult(nonForResultIndices[i]).getType());
@@ -1837,8 +1837,6 @@ struct WarpOpScfForOp : public WarpDistributionPattern {
     // Create the new `WarpOp` with the updated yield values and types.
     WarpExecuteOnLane0Op newWarpOp = moveRegionToNewWarpOpAndReplaceReturns(
         rewriter, warpOp, newWarpOpYieldValues, newWarpOpDistTypes);
-    newWarpOpYield = cast<gpu::YieldOp>(
-        newWarpOp.getBodyRegion().getBlocks().begin()->getTerminator());
 
     // Next, we create a new `ForOp` with the init args yielded by the new
     // `WarpOp`.
@@ -1912,7 +1910,7 @@ struct WarpOpScfForOp : public WarpDistributionPattern {
                                     newForOp.getResult(newIdx), newForOp);
     // Similarly, update any users of the `WarpOp` results that were not
     // results of the `ForOp`.
-    for (auto [origIdx, newIdx] : warpResultMapping)
+    for (auto [origIdx, newIdx] : nonForResultMapping)
       rewriter.replaceAllUsesWith(warpOp.getResult(origIdx),
                                   newWarpOp.getResult(newIdx));
     // Remove the original `WarpOp` and `ForOp`, they should not have any uses

>From 6297e47a62c8215fb25cd06f165241c10763d23c Mon Sep 17 00:00:00 2001
From: Charitha Saumya <charitha.saumya.gusthinna.waduge at intel.com>
Date: Fri, 11 Jul 2025 19:47:05 +0000
Subject: [PATCH 10/10] address comments

---
 .../Dialect/Vector/Transforms/VectorDistribute.cpp   | 12 ++++--------
 1 file changed, 4 insertions(+), 8 deletions(-)

diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorDistribute.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorDistribute.cpp
index 15c14bef37e76..e62031412eab6 100644
--- a/mlir/lib/Dialect/Vector/Transforms/VectorDistribute.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/VectorDistribute.cpp
@@ -17,12 +17,8 @@
 #include "mlir/IR/AffineExpr.h"
 #include "mlir/IR/Attributes.h"
 #include "mlir/IR/BuiltinTypes.h"
-#include "mlir/IR/Value.h"
 #include "mlir/Interfaces/SideEffectInterfaces.h"
-#include "mlir/Support/LLVM.h"
 #include "mlir/Transforms/RegionUtils.h"
-#include "llvm/ADT/DenseMap.h"
-#include "llvm/ADT/STLExtras.h"
 #include "llvm/ADT/SetVector.h"
 #include "llvm/ADT/SmallVectorExtras.h"
 #include "llvm/Support/FormatVariadic.h"
@@ -1787,11 +1783,11 @@ struct WarpOpScfForOp : public WarpDistributionPattern {
     // index and the corresponding new `WarpOp` yield value index (needed to
     // update users later).
     llvm::SmallDenseMap<unsigned, unsigned> nonForResultMapping;
-    for (auto [i, v] : llvm::enumerate(nonForYieldedValues)) {
-      nonForResultMapping[nonForResultIndices[i]] = newWarpOpYieldValues.size();
+    for (auto [i, v] :
+         llvm::zip_equal(nonForResultIndices, nonForYieldedValues)) {
+      nonForResultMapping[i] = newWarpOpYieldValues.size();
       newWarpOpYieldValues.push_back(v);
-      newWarpOpDistTypes.push_back(
-          warpOp.getResult(nonForResultIndices[i]).getType());
+      newWarpOpDistTypes.push_back(warpOp.getResult(i).getType());
     }
     // Create the new `WarpOp` with the updated yield values and types.
     WarpExecuteOnLane0Op newWarpOp = moveRegionToNewWarpOpAndReplaceReturns(



More information about the Mlir-commits mailing list