[Mlir-commits] [mlir] [MLIR][Vector] Add warp distribution for `scf.if` (PR #157119)
Artem Kroviakov
llvmlistbot at llvm.org
Tue Sep 9 03:20:57 PDT 2025
================
@@ -1713,6 +1713,209 @@ struct WarpOpInsert : public WarpDistributionPattern {
}
};
+struct WarpOpScfIfOp : public WarpDistributionPattern {
+ WarpOpScfIfOp(MLIRContext *ctx, DistributionMapFn fn, PatternBenefit b = 1)
+ : WarpDistributionPattern(ctx, b), distributionMapFn(std::move(fn)) {}
+ LogicalResult matchAndRewrite(WarpExecuteOnLane0Op warpOp,
+ PatternRewriter &rewriter) const override {
+ gpu::YieldOp warpOpYield = warpOp.getTerminator();
+ // Only pick up `IfOp` if it is the last op in the region.
+ Operation *lastNode = warpOpYield->getPrevNode();
+ auto ifOp = dyn_cast_or_null<scf::IfOp>(lastNode);
+ if (!ifOp)
+ return failure();
+
+ // The current `WarpOp` can yield two types of values:
+ // 1. Not results of `IfOp`:
+ // Preserve them in the new `WarpOp`.
+ // Collect their yield index.
+ // 2. Results of `IfOp`:
+ // They are not part of the new `WarpOp` results.
+ // Map current warp's yield operand index to `IfOp` result idx.
+ SmallVector<Value> nonIfYieldValues;
+ SmallVector<unsigned> nonIfYieldIndices;
+ llvm::SmallDenseMap<unsigned, unsigned> ifResultMapping;
+ llvm::SmallDenseMap<unsigned, VectorType> ifResultDistTypes;
+ for (OpOperand &yieldOperand : warpOpYield->getOpOperands()) {
+ const unsigned yieldOperandIdx = yieldOperand.getOperandNumber();
+ if (yieldOperand.get().getDefiningOp() != ifOp.getOperation()) {
+ nonIfYieldValues.push_back(yieldOperand.get());
+ nonIfYieldIndices.push_back(yieldOperandIdx);
+ continue;
+ }
+ OpResult ifResult = cast<OpResult>(yieldOperand.get());
+ const unsigned ifResultIdx = ifResult.getResultNumber();
+ ifResultMapping[yieldOperandIdx] = ifResultIdx;
+ // If this `ifOp` result is a vector type and it is yielded by the
+ // `WarpOp`, we keep track of the distributed type for this result.
+ if (!isa<VectorType>(ifResult.getType()))
+ continue;
+ VectorType distType =
+ cast<VectorType>(warpOp.getResult(yieldOperandIdx).getType());
+ ifResultDistTypes[ifResultIdx] = distType;
+ }
+
+ // Collect `WarpOp`-defined values used in `ifOp`; the new warp op returns
+ // them.
+ auto getEscapingValues = [&](Region &branch,
----------------
akroviakov wrote:
Done
https://github.com/llvm/llvm-project/pull/157119
More information about the Mlir-commits
mailing list