[Mlir-commits] [mlir] [mlir][vector] Refactor WarpOpScfForOp to support unused or swapped forOp results. (PR #147620)

Petr Kurapov llvmlistbot at llvm.org
Wed Jul 9 09:42:18 PDT 2025


================
@@ -1769,81 +1772,155 @@ struct WarpOpScfForOp : public WarpDistributionPattern {
               AffineMap map = distributionMapFn(operand->get());
               distType = getDistributedType(vecType, map, warpOp.getWarpSize());
             }
-            inputTypes.push_back(operand->get().getType());
-            distTypes.push_back(distType);
+            escapingValueInputTypes.push_back(operand->get().getType());
+            escapingValuedistTypes.push_back(distType);
           }
         });
 
-    if (llvm::is_contained(distTypes, Type{}))
+    if (llvm::is_contained(escapingValuedistTypes, Type{}))
       return failure();
+    // Warp op can yield two types of values:
+    // 1. Values that are not results of the forOp:
+    //    These values must also be yielded by the new warp op. Also, we need to
+    //    record the index mapping for these values to replace them later.
+    // 2. Values that are results of the forOp:
+    //    In this case, we record the index mapping between the warp op result
+    //    index and matching forOp result index.
+    SmallVector<Value> nonForYieldedValues;
+    SmallVector<unsigned> nonForResultIndices;
+    DenseMap<unsigned, unsigned> forResultMapping;
+    for (OpOperand &yieldOperand : newWarpOpYield->getOpOperands()) {
+      // Yielded value is not a result of the forOp.
+      if (yieldOperand.get().getDefiningOp() != forOp.getOperation()) {
+        nonForYieldedValues.push_back(yieldOperand.get());
+        nonForResultIndices.push_back(yieldOperand.getOperandNumber());
+        continue;
+      }
+      OpResult forResult = cast<OpResult>(yieldOperand.get());
+      forResultMapping[yieldOperand.getOperandNumber()] =
+          forResult.getResultNumber();
+    }
 
-    SmallVector<size_t> newRetIndices;
-    WarpExecuteOnLane0Op newWarpOp = moveRegionToNewWarpOpAndAppendReturns(
-        rewriter, warpOp, escapingValues.getArrayRef(), distTypes,
-        newRetIndices);
-    yield = cast<gpu::YieldOp>(
+    // Newly created warp op will yield values in following order:
+    // 1. All init args of the forOp.
+    // 2. All escaping values.
+    // 3. All non-for yielded values.
+    SmallVector<Value> newWarpOpYieldValues;
+    SmallVector<Type> newWarpOpDistTypes;
+    for (auto [i, initArg] : llvm::enumerate(forOp.getInitArgs())) {
+      newWarpOpYieldValues.push_back(initArg);
+      // Compute the distributed type for this init arg.
+      Type distType = initArg.getType();
+      if (auto vecType = dyn_cast<VectorType>(distType)) {
+        AffineMap map = distributionMapFn(initArg);
+        distType = getDistributedType(vecType, map, warpOp.getWarpSize());
+      }
+      newWarpOpDistTypes.push_back(distType);
+    }
+    // Insert escaping values and their distributed types.
+    newWarpOpYieldValues.insert(newWarpOpYieldValues.end(),
+                                escapingValues.begin(), escapingValues.end());
+    newWarpOpDistTypes.insert(newWarpOpDistTypes.end(),
+                              escapingValuedistTypes.begin(),
+                              escapingValuedistTypes.end());
+    // Next, we insert all non-for yielded values and their distributed types.
+    // We also create a mapping between the non-for yielded value index and the
+    // corresponding new warp op yield value index (needed to update users
+    // later).
+    DenseMap<unsigned, unsigned> warpResultMapping;
+    for (auto [i, v] : llvm::enumerate(nonForYieldedValues)) {
+      warpResultMapping[nonForResultIndices[i]] = newWarpOpYieldValues.size();
+      newWarpOpYieldValues.push_back(v);
+      newWarpOpDistTypes.push_back(
+          warpOp.getResult(nonForResultIndices[i]).getType());
+    }
+    // Create the new warp op with the updated yield values and types.
+    WarpExecuteOnLane0Op newWarpOp = moveRegionToNewWarpOpAndReplaceReturns(
+        rewriter, warpOp, newWarpOpYieldValues, newWarpOpDistTypes);
+    newWarpOpYield = cast<gpu::YieldOp>(
         newWarpOp.getBodyRegion().getBlocks().begin()->getTerminator());
 
-    SmallVector<Value> newOperands;
-    SmallVector<unsigned> resultIdx;
-    // Collect all the outputs coming from the forOp.
-    for (OpOperand &yieldOperand : yield->getOpOperands()) {
-      if (yieldOperand.get().getDefiningOp() != forOp.getOperation())
-        continue;
-      auto forResult = cast<OpResult>(yieldOperand.get());
-      newOperands.push_back(
-          newWarpOp.getResult(yieldOperand.getOperandNumber()));
-      yieldOperand.set(forOp.getInitArgs()[forResult.getResultNumber()]);
-      resultIdx.push_back(yieldOperand.getOperandNumber());
-    }
+    // Next, we create a new for op with the init args yielded by the new
+    // warp op.
+    unsigned escapingValuesStartIdx =
+        forOp.getInitArgs().size(); // ForOp init args are positioned before
+                                    // escaping values in the new warp op.
+    SmallVector<Value> newForOpOperands;
+    for (size_t i = 0; i < escapingValuesStartIdx; ++i)
+      newForOpOperands.push_back(newWarpOp.getResult(i));
 
+    // Create a new for op outside the new warp op region.
     OpBuilder::InsertionGuard g(rewriter);
     rewriter.setInsertionPointAfter(newWarpOp);
-
-    // Create a new for op outside the region with a WarpExecuteOnLane0Op
-    // region inside.
     auto newForOp = rewriter.create<scf::ForOp>(
         forOp.getLoc(), forOp.getLowerBound(), forOp.getUpperBound(),
-        forOp.getStep(), newOperands);
+        forOp.getStep(), newForOpOperands);
+    // Next, we insert a new warp op (called inner warp op) inside the
+    // newly created for op. This warp op will contain all ops that were
----------------
kurapov-peter wrote:

By the way, the comments would be easier to read if they highlighted the op names (e.g., with backticks), for example:
```suggestion
    // newly created `ForOp`. This warp op will contain all ops that were
```

https://github.com/llvm/llvm-project/pull/147620


More information about the Mlir-commits mailing list