[Mlir-commits] [mlir] [MLIR][XeGPU] Add support for cross-subgroup reduction from wg to sg (PR #170936)
Jianhui Li
llvmlistbot at llvm.org
Fri Dec 19 20:55:04 PST 2025
================
@@ -1161,64 +1161,338 @@ struct WgToSgVectorShapeCastOp
}
};
-/// Pattern for lowering vector.multi_reduction op to subgroup level.
-/// Current limitation: the sg_layout in the reduced dimension being 1
-/// so that reduction is local to subgroup & no cross-subgroup communication is
-/// needed.
-/// TODO: Add cases to handle more general situations which require SLM access.
+/// Converts the multi-dimensional subgroup indices selected by `dims` into a
+/// single linear offset. It is used to compute memory offsets in shared local
+/// memory (SLM) for cross-subgroup reduction coordination.
+///
+/// Parameters:
+/// - sgIds: Multi-dimensional subgroup indices, indexed by dimension
+///   (e.g., [sgId_x, sgId_y, sgId_z]).
+/// - dims: Dimensions to include in the linearization (e.g., [0, 2] for the x
+///   and z dims). Their order matters; see the formula below.
+/// - sgLayout: Subgroup layout sizes for each dimension (e.g., [4, 8, 2] means
+///   4x8x2 subgroups).
+///
+/// Linearization formula (mixed radix; the FIRST entry of `dims` varies
+/// fastest):
+///   offset   = sum_i sgIds[dims[i]] * stride_i
+///   stride_0 = 1
+///   stride_i = stride_{i-1} * sgLayout[dims[i-1]]
+/// In other words, the stride of a dimension is the product of the sgLayout
+/// sizes of the dimensions listed BEFORE it in `dims`. This is the reverse of
+/// row-major order over `dims`, where the last listed dimension would get
+/// stride 1. Offsets that must refer to the same SLM location therefore have
+/// to be computed with the same `dims` order.
+///
+/// Example:
+/// - sgLayout = [4, 8, 2], dims = [0, 2] (linearize x and z dimensions)
+/// - sgIds = [1, 3, 1] (subgroup at position x=1, y=3, z=1)
+/// - Calculation:
+///   * dim=0: stride=1, term = sgIds[0] * 1 = 1 * 1 = 1
+///   * dim=2: stride=sgLayout[0]=4, term = sgIds[2] * 4 = 1 * 4 = 4
+///   * linearizedOffset = 1 + 4 = 5
+///
+/// As long as every sgIds[d] lies in [0, sgLayout[d]), each combination of
+/// subgroup positions in `dims` maps to a distinct offset in
+/// [0, product of sgLayout[d] over d in dims). That offset is used for SLM
+/// row/column addressing.
+static Value linearizeSubgroupIndices(ConversionPatternRewriter &rewriter,
+                                      Location loc, ArrayRef<Value> sgIds,
+                                      ArrayRef<int64_t> dims,
+                                      ArrayRef<int64_t> sgLayout) {
+  Value linearizedOffset = arith::ConstantIndexOp::create(rewriter, loc, 0);
+  int64_t stride = 1;
+
+  for (int64_t dim : dims) {
+    Value dimVal = sgIds[dim];
+    Value strideVal = arith::ConstantIndexOp::create(rewriter, loc, stride);
+    Value term = arith::MulIOp::create(rewriter, loc, dimVal, strideVal);
+    linearizedOffset =
+        arith::AddIOp::create(rewriter, loc, linearizedOffset, term);
+    // The next listed dimension varies more slowly: its stride is the product
+    // of the sizes of all dimensions linearized so far.
+    stride *= sgLayout[dim];
+  }
+
+  return linearizedOffset;
+}
+
+/// Emits the arith operation that combines `lhs` and `rhs` according to the
+/// vector combining `kind`. For ADD and MUL, the float or integer variant is
+/// chosen from the element type of `lhs` (the type itself for scalars). Every
+/// other kind maps onto exactly one arith op.
+static Value reductionOpKind(ConversionPatternRewriter &rewriter, Location loc,
+                             vector::CombiningKind kind, Value lhs, Value rhs) {
+  using vector::CombiningKind;
+  const bool hasFloatElements =
+      isa<FloatType>(getElementTypeOrSelf(lhs.getType()));
+
+  switch (kind) {
+  // Arithmetic kinds: pick the float or integer flavour.
+  case CombiningKind::ADD:
+    if (hasFloatElements)
+      return arith::AddFOp::create(rewriter, loc, lhs, rhs).getResult();
+    return arith::AddIOp::create(rewriter, loc, lhs, rhs).getResult();
+  case CombiningKind::MUL:
+    if (hasFloatElements)
+      return arith::MulFOp::create(rewriter, loc, lhs, rhs).getResult();
+    return arith::MulIOp::create(rewriter, loc, lhs, rhs).getResult();
+
+  // Integer min/max (signed and unsigned).
+  case CombiningKind::MINSI:
+    return arith::MinSIOp::create(rewriter, loc, lhs, rhs).getResult();
+  case CombiningKind::MINUI:
+    return arith::MinUIOp::create(rewriter, loc, lhs, rhs).getResult();
+  case CombiningKind::MAXSI:
+    return arith::MaxSIOp::create(rewriter, loc, lhs, rhs).getResult();
+  case CombiningKind::MAXUI:
+    return arith::MaxUIOp::create(rewriter, loc, lhs, rhs).getResult();
+
+  // Bitwise kinds.
+  case CombiningKind::AND:
+    return arith::AndIOp::create(rewriter, loc, lhs, rhs).getResult();
+  case CombiningKind::OR:
+    return arith::OrIOp::create(rewriter, loc, lhs, rhs).getResult();
+  case CombiningKind::XOR:
+    return arith::XOrIOp::create(rewriter, loc, lhs, rhs).getResult();
+
+  // Floating-point min/max variants.
+  case CombiningKind::MINNUMF:
+    return arith::MinNumFOp::create(rewriter, loc, lhs, rhs).getResult();
+  case CombiningKind::MAXNUMF:
+    return arith::MaxNumFOp::create(rewriter, loc, lhs, rhs).getResult();
+  case CombiningKind::MINIMUMF:
+    return arith::MinimumFOp::create(rewriter, loc, lhs, rhs).getResult();
+  case CombiningKind::MAXIMUMF:
+    return arith::MaximumFOp::create(rewriter, loc, lhs, rhs).getResult();
+  }
+  llvm_unreachable("unsupported OpKind");
+}
+
+/// This pattern transforms vector.multi_reduction operations from
+/// workgroup-level to subgroup-level execution, with support for multiple
+/// reduction dimensions.
+///
+/// Steps include:
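+/// 1. LOCAL REDUCTION:
+///    - Each subgroup performs a local reduction on its data slice.
+///    - Uses a ZERO accumulator so that the original accumulator is not
+///      applied more than once when partial results are later combined
+///      across subgroups.
+///      NOTE(review): zero is the identity element only for some combining
+///      kinds (e.g. ADD, OR, XOR, MAXUI). For MUL, AND, MINUI, signed MIN/MAX
+///      and the floating-point MIN/MAX kinds, a zero seed changes the result
+///      (e.g. MUL always yields 0), so the kind's neutral element should be
+///      used instead. TODO confirm and fix.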
+///
+/// 2. CROSS-SUBGROUP :
+/// - Determines if cross-subgroup reduction is needed (when sg_layout > 1 in
+/// reduction dims)
+/// - If not needed, adds original accumulator and returns local results
+///
+/// 3. SHARED LOCAL MEMORY (SLM) PHASE (when cross-subgroup reduction needed):
+/// a) SLM Layout Design:
+/// - Rows: subgroups participating in reduction (product of sg_layout in
+/// reduction dims)
+/// - Cols: total result elements across non-reduction dimensions
+///
+/// b) Store Phase:
+/// - Each subgroup stores its local reduction result to SLM
+/// - Row offset: linearized index of subgroup in reduction dimensions
+/// - Col offset: linearized index of subgroup in non-reduction dimensions
+///
+/// c) Load and Final Reduction Phase:
+/// - Each subgroup loads a column of data (all reduction participants for
+/// its position)
+/// - Performs final reduction along the loaded dimension
+/// - Adds original accumulator to get final result
+///
struct WgToSgMultiDimReductionOp
: public OpConversionPattern<vector::MultiDimReductionOp> {
using OpConversionPattern<vector::MultiDimReductionOp>::OpConversionPattern;
LogicalResult
matchAndRewrite(vector::MultiDimReductionOp op, OneToNOpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
+ Location loc = op.getLoc();
+
VectorType srcType = op.getSourceVectorType();
VectorType dstType = dyn_cast<VectorType>(op.getResult().getType());
if (!dstType)
return failure();
- auto srcShape = srcType.getShape();
+ auto originalSrcShape = srcType.getShape();
xegpu::DistributeLayoutAttr layout =
xegpu::getDistributeLayoutAttr(op.getResult());
+
if (!layout || !layout.isForWorkgroup())
return failure();
auto reductionDims = llvm::to_vector(op.getReductionDims());
- SmallVector<int64_t> sgLayout = llvm::cast<xegpu::SliceAttr>(layout)
- .getParent()
- .getEffectiveSgLayoutAsInt();
- SmallVector<int64_t> sgData = llvm::cast<xegpu::SliceAttr>(layout)
- .getParent()
- .getEffectiveSgDataAsInt();
-
- // Check that the sgLayout in the reduced dimension is 1 and
- // each sg gets the entire slice to reduce.
- for (int64_t dim : reductionDims) {
- if (sgLayout[dim] != 1 || sgData[dim] != srcShape[dim])
- return rewriter.notifyMatchFailure(
- op,
- "sgLayout in each reduced dimension must be 1 and sgData in the "
- "reduced dim must match srcShape in that dim");
+ // Get sg_layout and sg_data from the parent layout
+ SmallVector<int64_t> sgLayout;
+ SmallVector<int64_t> sgData;
+ if (auto sliceAttr = dyn_cast<xegpu::SliceAttr>(layout)) {
+ sgLayout = sliceAttr.getParent().getEffectiveSgLayoutAsInt();
+ sgData = sliceAttr.getParent().getEffectiveSgDataAsInt();
+ } else
+ return rewriter.notifyMatchFailure(
+ op, "Reduction should have SliceAttr layout");
+
+ Type elemTy = dstType.getElementType();
+
+ // Step 1: perform local subgroup reductions with ZERO accumulator
+ SmallVector<Value> localReductions;
+ SmallVector<int64_t> sgShape =
+ getSgShapeAndCount(originalSrcShape, layout).first;
+ VectorType newDstType = VectorType::get(sgShape, elemTy);
+ for (auto sgSrc : adaptor.getSource()) {
+ // Create ZERO accumulator for local reduction
+ auto zeroLocalAcc = arith::ConstantOp::create(
+ rewriter, loc, newDstType,
+ DenseElementsAttr::get(newDstType, rewriter.getZeroAttr(elemTy)));
+ // Local reduction with ZERO accumulator
+ auto localReduce = vector::MultiDimReductionOp::create(
+ rewriter, loc, newDstType, op.getKind(), sgSrc,
+ zeroLocalAcc.getResult(), reductionDims);
+ localReductions.push_back(localReduce.getResult());
}
- SmallVector<int64_t> sgShape = getSgShapeAndCount(srcShape, layout).first;
+ // Check if cross-subgroup reduction is needed for any reduction dimension
+ bool needsCrossSubgroupReduction = false;
+ SmallVector<int64_t> crossSgReductionDims;
+ for (int64_t reductionDim : reductionDims) {
+ if (sgLayout[reductionDim] > 1) {
+ needsCrossSubgroupReduction = true;
----------------
Jianhui-Li wrote:
This is not necessarily true if localReductions along this dimension is already of unit size.
https://github.com/llvm/llvm-project/pull/170936
More information about the Mlir-commits
mailing list