[Mlir-commits] [mlir] [MLIR][Vector] Add warp distribution for `scf.if` (PR #157119)
Artem Kroviakov
llvmlistbot at llvm.org
Wed Sep 10 01:19:17 PDT 2025
================
@@ -371,6 +371,36 @@ static VectorType getDistributedType(VectorType originalType, AffineMap map,
return targetType;
}
+/// Given a warpOp that contains ops with regions, the corresponding op's
+/// "inner" region and the distributionMapFn, get all values used by the op's
+/// region that are defined within the warpOp. Return the set of values, their
+/// types and their distributed types.
+std::tuple<llvm::SmallSetVector<Value, 32>, SmallVector<Type>,
+ SmallVector<Type>>
+getInnerRegionEscapingValues(WarpExecuteOnLane0Op warpOp, Region &innerRegion,
+ DistributionMapFn distributionMapFn) {
+ llvm::SmallSetVector<Value, 32> escapingValues;
+ SmallVector<Type> escapingValueTypes;
+ SmallVector<Type> escapingValueDistTypes; // to yield from the new warpOp
+ if (innerRegion.empty())
+ return {escapingValues, escapingValueTypes, escapingValueDistTypes};
+ mlir::visitUsedValuesDefinedAbove(innerRegion, [&](OpOperand *operand) {
+ Operation *parent = operand->get().getParentRegion()->getParentOp();
+ if (warpOp->isAncestor(parent)) {
+ if (!escapingValues.insert(operand->get()))
+ return;
+ Type distType = operand->get().getType();
+ if (auto vecType = dyn_cast<VectorType>(distType)) {
+ AffineMap map = distributionMapFn(operand->get());
+ distType = getDistributedType(vecType, map, warpOp.getWarpSize());
+ }
+ escapingValueTypes.push_back(operand->get().getType());
+ escapingValueDistTypes.push_back(distType);
+ }
+ });
+ return {escapingValues, escapingValueTypes, escapingValueDistTypes};
----------------
akroviakov wrote:
Added `std::move` to avoid copy-constructing the tuple elements. At the call site, the structured binding should then take the returned tuple implicitly, without any further copies.
https://github.com/llvm/llvm-project/pull/157119
More information about the Mlir-commits
mailing list