[Mlir-commits] [mlir] [MLIR][Vector] Add warp distribution for `scf.if` (PR #157119)
Charitha Saumya
llvmlistbot at llvm.org
Tue Sep 9 23:01:39 PDT 2025
================
@@ -1713,6 +1713,209 @@ struct WarpOpInsert : public WarpDistributionPattern {
}
};
+struct WarpOpScfIfOp : public WarpDistributionPattern {
+ WarpOpScfIfOp(MLIRContext *ctx, DistributionMapFn fn, PatternBenefit b = 1)
+ : WarpDistributionPattern(ctx, b), distributionMapFn(std::move(fn)) {}
+ LogicalResult matchAndRewrite(WarpExecuteOnLane0Op warpOp,
+ PatternRewriter &rewriter) const override {
+ gpu::YieldOp warpOpYield = warpOp.getTerminator();
+ // Only pick up `IfOp` if it is the last op in the region.
+ Operation *lastNode = warpOpYield->getPrevNode();
+ auto ifOp = dyn_cast_or_null<scf::IfOp>(lastNode);
+ if (!ifOp)
+ return failure();
+
+ // The current `WarpOp` can yield two types of values:
+ // 1. Not results of `IfOp`:
+ // Preserve them in the new `WarpOp`.
+ // Collect their yield index.
+ // 2. Results of `IfOp`:
+ // They are not part of the new `WarpOp` results.
+ // Map current warp's yield operand index to `IfOp` result idx.
+ SmallVector<Value> nonIfYieldValues;
+ SmallVector<unsigned> nonIfYieldIndices;
+ llvm::SmallDenseMap<unsigned, unsigned> ifResultMapping;
+ llvm::SmallDenseMap<unsigned, VectorType> ifResultDistTypes;
+ for (OpOperand &yieldOperand : warpOpYield->getOpOperands()) {
+ const unsigned yieldOperandIdx = yieldOperand.getOperandNumber();
+ if (yieldOperand.get().getDefiningOp() != ifOp.getOperation()) {
+ nonIfYieldValues.push_back(yieldOperand.get());
+ nonIfYieldIndices.push_back(yieldOperandIdx);
+ continue;
+ }
+ OpResult ifResult = cast<OpResult>(yieldOperand.get());
+ const unsigned ifResultIdx = ifResult.getResultNumber();
+ ifResultMapping[yieldOperandIdx] = ifResultIdx;
+ // If this `ifOp` result is a vector type and it is yielded by the
+ // `WarpOp`, we keep track of the distributed type for this result.
+ if (!isa<VectorType>(ifResult.getType()))
+ continue;
+ VectorType distType =
+ cast<VectorType>(warpOp.getResult(yieldOperandIdx).getType());
+ ifResultDistTypes[ifResultIdx] = distType;
+ }
+
+ // Collect `WarpOp`-defined values used in `ifOp`; the new warp op returns
+ // them.
+ auto getEscapingValues = [&](Region &branch,
+ llvm::SmallSetVector<Value, 32> &values,
+ SmallVector<Type> &inputTypes,
+ SmallVector<Type> &distTypes) {
+ if (branch.empty())
+ return;
+ mlir::visitUsedValuesDefinedAbove(branch, [&](OpOperand *operand) {
+ Operation *parent = operand->get().getParentRegion()->getParentOp();
+ if (warpOp->isAncestor(parent)) {
+ if (!values.insert(operand->get()))
+ return;
+ Type distType = operand->get().getType();
+ if (auto vecType = dyn_cast<VectorType>(distType)) {
+ AffineMap map = distributionMapFn(operand->get());
+ distType = getDistributedType(vecType, map, warpOp.getWarpSize());
+ }
+ inputTypes.push_back(operand->get().getType());
+ distTypes.push_back(distType);
+ }
+ });
+ };
+ llvm::SmallSetVector<Value, 32> escapingValuesThen;
+ SmallVector<Type> escapingValueInputTypesThen; // inner warp op block args
+ SmallVector<Type> escapingValueDistTypesThen; // new warp returns
+ getEscapingValues(ifOp.getThenRegion(), escapingValuesThen,
+ escapingValueInputTypesThen, escapingValueDistTypesThen);
+ llvm::SmallSetVector<Value, 32> escapingValuesElse;
+ SmallVector<Type> escapingValueInputTypesElse; // inner warp op block args
+ SmallVector<Type> escapingValueDistTypesElse; // new warp returns
+ getEscapingValues(ifOp.getElseRegion(), escapingValuesElse,
+ escapingValueInputTypesElse, escapingValueDistTypesElse);
+
+ if (llvm::is_contained(escapingValueDistTypesThen, Type{}) ||
+ llvm::is_contained(escapingValueDistTypesElse, Type{}))
+ return failure();
+
+ // The new `WarpOp` groups yielded values in the following order:
+ // 1. Branch condition
+ // 2. Escaping values of the then branch
+ // 3. Escaping values of the else branch
+ // 4. All non-`ifOp` yielded values.
+ SmallVector<Value> newWarpOpYieldValues{ifOp.getCondition()};
+ newWarpOpYieldValues.append(escapingValuesThen.begin(),
+ escapingValuesThen.end());
+ newWarpOpYieldValues.append(escapingValuesElse.begin(),
+ escapingValuesElse.end());
+ SmallVector<Type> newWarpOpDistTypes{ifOp.getCondition().getType()};
+ newWarpOpDistTypes.append(escapingValueDistTypesThen.begin(),
+ escapingValueDistTypesThen.end());
+ newWarpOpDistTypes.append(escapingValueDistTypesElse.begin(),
+ escapingValueDistTypesElse.end());
+
+ llvm::SmallDenseMap<unsigned, unsigned> origToNewYieldIdx;
+ for (auto [idx, val] :
+ llvm::zip_equal(nonIfYieldIndices, nonIfYieldValues)) {
+ origToNewYieldIdx[idx] = newWarpOpYieldValues.size();
+ newWarpOpYieldValues.push_back(val);
+ newWarpOpDistTypes.push_back(warpOp.getResult(idx).getType());
+ }
+ // Create the new `WarpOp` with the updated yield values and types.
+ WarpExecuteOnLane0Op newWarpOp = moveRegionToNewWarpOpAndReplaceReturns(
+ rewriter, warpOp, newWarpOpYieldValues, newWarpOpDistTypes);
+ // `ifOp` returns the result of the inner warp op.
+ SmallVector<Type> newIfOpDistResTypes;
+ for (auto [i, res] : llvm::enumerate(ifOp.getResults())) {
+ Type distType = cast<Value>(res).getType();
+ if (auto vecType = dyn_cast<VectorType>(distType)) {
+ AffineMap map = distributionMapFn(cast<Value>(res));
+ distType = ifResultDistTypes.count(i)
+ ? ifResultDistTypes[i]
+ : getDistributedType(vecType, map, warpOp.getWarpSize());
+ }
+ newIfOpDistResTypes.push_back(distType);
+ }
+ // Create a new `IfOp` outside the new `WarpOp` region.
+ OpBuilder::InsertionGuard g(rewriter);
+ rewriter.setInsertionPointAfter(newWarpOp);
+ auto newIfOp = scf::IfOp::create(
+ rewriter, ifOp.getLoc(), newIfOpDistResTypes, newWarpOp.getResult(0),
+ static_cast<bool>(ifOp.thenBlock()),
+ static_cast<bool>(ifOp.elseBlock()));
+
+ auto processBranch = [&](Block *oldIfBranch, Block *newIfBranch,
----------------
charithaintc wrote:
Let's go back to the original lambda version then. That looks cleaner. Sorry about the extra work. :-)
https://github.com/llvm/llvm-project/pull/157119
More information about the Mlir-commits
mailing list