[Mlir-commits] [mlir] [mlir][vector] Refactor WarpOpScfForOp to support unused or swapped forOp results. (PR #147620)

Charitha Saumya llvmlistbot at llvm.org
Fri Jul 11 12:44:52 PDT 2025


================
@@ -1769,81 +1772,155 @@ struct WarpOpScfForOp : public WarpDistributionPattern {
               AffineMap map = distributionMapFn(operand->get());
               distType = getDistributedType(vecType, map, warpOp.getWarpSize());
             }
-            inputTypes.push_back(operand->get().getType());
-            distTypes.push_back(distType);
+            escapingValueInputTypes.push_back(operand->get().getType());
+            escapingValuedistTypes.push_back(distType);
           }
         });
 
-    if (llvm::is_contained(distTypes, Type{}))
+    if (llvm::is_contained(escapingValuedistTypes, Type{}))
       return failure();
+    // `WarpOp` can yield two types of values:
+    // 1. Values that are not results of the `ForOp`:
+    //    These values must also be yielded by the new `WarpOp`. Also, we need
+    //    to record the index mapping for these values to replace them later.
+    // 2. Values that are results of the `ForOp`:
+    //    In this case, we record the index mapping between the `WarpOp` result
+    //    index and matching `ForOp` result index.
+    SmallVector<Value> nonForYieldedValues;
+    SmallVector<unsigned> nonForResultIndices;
+    llvm::SmallDenseMap<unsigned, unsigned> forResultMapping;
+    for (OpOperand &yieldOperand : newWarpOpYield->getOpOperands()) {
+      // Yielded value is not a result of the forOp.
+      if (yieldOperand.get().getDefiningOp() != forOp.getOperation()) {
+        nonForYieldedValues.push_back(yieldOperand.get());
+        nonForResultIndices.push_back(yieldOperand.getOperandNumber());
+        continue;
+      }
+      OpResult forResult = cast<OpResult>(yieldOperand.get());
+      forResultMapping[yieldOperand.getOperandNumber()] =
+          forResult.getResultNumber();
+    }
 
-    SmallVector<size_t> newRetIndices;
-    WarpExecuteOnLane0Op newWarpOp = moveRegionToNewWarpOpAndAppendReturns(
-        rewriter, warpOp, escapingValues.getArrayRef(), distTypes,
-        newRetIndices);
-    yield = cast<gpu::YieldOp>(
+    // Newly created `WarpOp` will yield values in the following order:
+    // 1. All init args of the `ForOp`.
+    // 2. All escaping values.
+    // 3. All non-`ForOp` yielded values.
+    SmallVector<Value> newWarpOpYieldValues;
+    SmallVector<Type> newWarpOpDistTypes;
+    for (auto [i, initArg] : llvm::enumerate(forOp.getInitArgs())) {
+      newWarpOpYieldValues.push_back(initArg);
+      // Compute the distributed type for this init arg.
+      Type distType = initArg.getType();
+      if (auto vecType = dyn_cast<VectorType>(distType)) {
+        AffineMap map = distributionMapFn(initArg);
+        distType = getDistributedType(vecType, map, warpOp.getWarpSize());
+      }
+      newWarpOpDistTypes.push_back(distType);
+    }
+    // Insert escaping values and their distributed types.
+    newWarpOpYieldValues.insert(newWarpOpYieldValues.end(),
+                                escapingValues.begin(), escapingValues.end());
+    newWarpOpDistTypes.insert(newWarpOpDistTypes.end(),
+                              escapingValuedistTypes.begin(),
+                              escapingValuedistTypes.end());
+    // Next, we insert all non-`ForOp` yielded values and their distributed
+    // types. We also create a mapping between the non-`ForOp` yielded value
+    // index and the corresponding new `WarpOp` yield value index (needed to
+    // update users later).
+    llvm::SmallDenseMap<unsigned, unsigned> warpResultMapping;
+    for (auto [i, v] : llvm::enumerate(nonForYieldedValues)) {
----------------
charithaintc wrote:

Ah, I see. It makes sense now. Fixed it.

https://github.com/llvm/llvm-project/pull/147620


More information about the Mlir-commits mailing list