[Mlir-commits] [mlir] [mlir][vector] Refactor WarpOpScfForOp to support unused or swapped forOp results. (PR #147620)
Petr Kurapov
llvmlistbot at llvm.org
Wed Jul 9 09:42:18 PDT 2025
================
@@ -1769,81 +1772,155 @@ struct WarpOpScfForOp : public WarpDistributionPattern {
AffineMap map = distributionMapFn(operand->get());
distType = getDistributedType(vecType, map, warpOp.getWarpSize());
}
- inputTypes.push_back(operand->get().getType());
- distTypes.push_back(distType);
+ escapingValueInputTypes.push_back(operand->get().getType());
+ escapingValuedistTypes.push_back(distType);
}
});
- if (llvm::is_contained(distTypes, Type{}))
+ if (llvm::is_contained(escapingValuedistTypes, Type{}))
return failure();
+ // Warp op can yield two types of values:
+ // 1. Values that are not results of the forOp:
+ // These values must also be yielded by the new warp op. Also, we need to
+ // record the index mapping for these values to replace them later.
+ // 2. Values that are results of the forOp:
+ // In this case, we record the index mapping between the warp op result
+ // index and matching forOp result index.
+ SmallVector<Value> nonForYieldedValues;
+ SmallVector<unsigned> nonForResultIndices;
+ DenseMap<unsigned, unsigned> forResultMapping;
+ for (OpOperand &yieldOperand : newWarpOpYield->getOpOperands()) {
+ // Yielded value is not a result of the forOp.
+ if (yieldOperand.get().getDefiningOp() != forOp.getOperation()) {
+ nonForYieldedValues.push_back(yieldOperand.get());
+ nonForResultIndices.push_back(yieldOperand.getOperandNumber());
+ continue;
+ }
+ OpResult forResult = cast<OpResult>(yieldOperand.get());
+ forResultMapping[yieldOperand.getOperandNumber()] =
+ forResult.getResultNumber();
+ }
- SmallVector<size_t> newRetIndices;
- WarpExecuteOnLane0Op newWarpOp = moveRegionToNewWarpOpAndAppendReturns(
- rewriter, warpOp, escapingValues.getArrayRef(), distTypes,
- newRetIndices);
- yield = cast<gpu::YieldOp>(
+ // Newly created warp op will yield values in following order:
+ // 1. All init args of the forOp.
+ // 2. All escaping values.
+ // 3. All non-for yielded values.
+ SmallVector<Value> newWarpOpYieldValues;
+ SmallVector<Type> newWarpOpDistTypes;
+ for (auto [i, initArg] : llvm::enumerate(forOp.getInitArgs())) {
+ newWarpOpYieldValues.push_back(initArg);
+ // Compute the distributed type for this init arg.
+ Type distType = initArg.getType();
+ if (auto vecType = dyn_cast<VectorType>(distType)) {
+ AffineMap map = distributionMapFn(initArg);
+ distType = getDistributedType(vecType, map, warpOp.getWarpSize());
+ }
+ newWarpOpDistTypes.push_back(distType);
+ }
+ // Insert escaping values and their distributed types.
+ newWarpOpYieldValues.insert(newWarpOpYieldValues.end(),
+ escapingValues.begin(), escapingValues.end());
+ newWarpOpDistTypes.insert(newWarpOpDistTypes.end(),
+ escapingValuedistTypes.begin(),
+ escapingValuedistTypes.end());
+ // Next, we insert all non-for yielded values and their distributed types.
+ // We also create a mapping between the non-for yielded value index and the
+ // corresponding new warp op yield value index (needed to update users
+ // later).
+ DenseMap<unsigned, unsigned> warpResultMapping;
+ for (auto [i, v] : llvm::enumerate(nonForYieldedValues)) {
+ warpResultMapping[nonForResultIndices[i]] = newWarpOpYieldValues.size();
+ newWarpOpYieldValues.push_back(v);
+ newWarpOpDistTypes.push_back(
+ warpOp.getResult(nonForResultIndices[i]).getType());
+ }
+ // Create the new warp op with the updated yield values and types.
+ WarpExecuteOnLane0Op newWarpOp = moveRegionToNewWarpOpAndReplaceReturns(
+ rewriter, warpOp, newWarpOpYieldValues, newWarpOpDistTypes);
+ newWarpOpYield = cast<gpu::YieldOp>(
newWarpOp.getBodyRegion().getBlocks().begin()->getTerminator());
- SmallVector<Value> newOperands;
- SmallVector<unsigned> resultIdx;
- // Collect all the outputs coming from the forOp.
- for (OpOperand &yieldOperand : yield->getOpOperands()) {
- if (yieldOperand.get().getDefiningOp() != forOp.getOperation())
- continue;
- auto forResult = cast<OpResult>(yieldOperand.get());
- newOperands.push_back(
- newWarpOp.getResult(yieldOperand.getOperandNumber()));
- yieldOperand.set(forOp.getInitArgs()[forResult.getResultNumber()]);
- resultIdx.push_back(yieldOperand.getOperandNumber());
- }
+ // Next, we create a new for op with the init args yielded by the new
+ // warp op.
+ unsigned escapingValuesStartIdx =
+ forOp.getInitArgs().size(); // ForOp init args are positioned before
+ // escaping values in the new warp op.
+ SmallVector<Value> newForOpOperands;
+ for (size_t i = 0; i < escapingValuesStartIdx; ++i)
+ newForOpOperands.push_back(newWarpOp.getResult(i));
+ // Create a new for op outside the new warp op region.
OpBuilder::InsertionGuard g(rewriter);
rewriter.setInsertionPointAfter(newWarpOp);
-
- // Create a new for op outside the region with a WarpExecuteOnLane0Op
- // region inside.
auto newForOp = rewriter.create<scf::ForOp>(
forOp.getLoc(), forOp.getLowerBound(), forOp.getUpperBound(),
- forOp.getStep(), newOperands);
+ forOp.getStep(), newForOpOperands);
+ // Next, we insert a new warp op (called inner warp op) inside the
+ // newly created for op. This warp op will contain all ops that were
----------------
kurapov-peter wrote:
btw, the comments would be easier to read if they highlight the op names, e.g.
```suggestion
// newly created `ForOp`. This warp op will contain all ops that were
```
https://github.com/llvm/llvm-project/pull/147620
More information about the Mlir-commits mailing list