[Mlir-commits] [mlir] [mlir][vector] Refactor WarpOpScfForOp to support unused or swapped forOp results. (PR #147620)
Charitha Saumya
llvmlistbot at llvm.org
Fri Jul 11 12:44:52 PDT 2025
================
@@ -1769,81 +1772,155 @@ struct WarpOpScfForOp : public WarpDistributionPattern {
AffineMap map = distributionMapFn(operand->get());
distType = getDistributedType(vecType, map, warpOp.getWarpSize());
}
- inputTypes.push_back(operand->get().getType());
- distTypes.push_back(distType);
+ escapingValueInputTypes.push_back(operand->get().getType());
+ escapingValuedistTypes.push_back(distType);
}
});
- if (llvm::is_contained(distTypes, Type{}))
+ if (llvm::is_contained(escapingValuedistTypes, Type{}))
return failure();
+ // `WarpOp` can yield two types of values:
+ // 1. Values that are not results of the `ForOp`:
+ // These values must also be yielded by the new `WarpOp`. Also, we need
+ // to record the index mapping for these values to replace them later.
+ // 2. Values that are results of the `ForOp`:
+ // In this case, we record the index mapping between the `WarpOp` result
+ // index and matching `ForOp` result index.
+ SmallVector<Value> nonForYieldedValues;
+ SmallVector<unsigned> nonForResultIndices;
+ llvm::SmallDenseMap<unsigned, unsigned> forResultMapping;
+ for (OpOperand &yieldOperand : newWarpOpYield->getOpOperands()) {
+ // Yielded value is not a result of the forOp.
+ if (yieldOperand.get().getDefiningOp() != forOp.getOperation()) {
+ nonForYieldedValues.push_back(yieldOperand.get());
+ nonForResultIndices.push_back(yieldOperand.getOperandNumber());
+ continue;
+ }
+ OpResult forResult = cast<OpResult>(yieldOperand.get());
+ forResultMapping[yieldOperand.getOperandNumber()] =
+ forResult.getResultNumber();
+ }
- SmallVector<size_t> newRetIndices;
- WarpExecuteOnLane0Op newWarpOp = moveRegionToNewWarpOpAndAppendReturns(
- rewriter, warpOp, escapingValues.getArrayRef(), distTypes,
- newRetIndices);
- yield = cast<gpu::YieldOp>(
+ // Newly created `WarpOp` will yield values in following order:
+ // 1. All init args of the `ForOp`.
+ // 2. All escaping values.
+ // 3. All non-`ForOp` yielded values.
+ SmallVector<Value> newWarpOpYieldValues;
+ SmallVector<Type> newWarpOpDistTypes;
+ for (auto [i, initArg] : llvm::enumerate(forOp.getInitArgs())) {
+ newWarpOpYieldValues.push_back(initArg);
+ // Compute the distributed type for this init arg.
+ Type distType = initArg.getType();
+ if (auto vecType = dyn_cast<VectorType>(distType)) {
+ AffineMap map = distributionMapFn(initArg);
+ distType = getDistributedType(vecType, map, warpOp.getWarpSize());
+ }
+ newWarpOpDistTypes.push_back(distType);
+ }
+ // Insert escaping values and their distributed types.
+ newWarpOpYieldValues.insert(newWarpOpYieldValues.end(),
+ escapingValues.begin(), escapingValues.end());
+ newWarpOpDistTypes.insert(newWarpOpDistTypes.end(),
+ escapingValuedistTypes.begin(),
+ escapingValuedistTypes.end());
+ // Next, we insert all non-`ForOp` yielded values and their distributed
+ // types. We also create a mapping between the non-`ForOp` yielded value
+ // index and the corresponding new `WarpOp` yield value index (needed to
+ // update users later).
+ llvm::SmallDenseMap<unsigned, unsigned> warpResultMapping;
+ for (auto [i, v] : llvm::enumerate(nonForYieldedValues)) {
----------------
charithaintc wrote:
Ah, I see. It makes sense now. Fixed it.
https://github.com/llvm/llvm-project/pull/147620
More information about the Mlir-commits
mailing list