[Mlir-commits] [mlir] [MLIR][XeGPU] Add support for cross-subgroup reduction from wg to sg (PR #170936)
Jianhui Li
llvmlistbot at llvm.org
Tue Dec 9 11:04:10 PST 2025
================
@@ -1152,64 +1152,232 @@ struct WgToSgVectorShapeCastOp
}
};
-/// Pattern for lowering vector.multi_reduction op to subgroup level.
-/// Current limitation: the sg_layout in the reduced dimension being 1
-/// so that reduction is local to subgroup & no cross-subgroup communication is
-/// needed.
-/// TODO: Add cases to handle more general situations which require SLM access.
+// This pattern transforms vector.multi_dim_reduction ops to work at subgroup
+// level.
struct WgToSgMultiDimReductionOp
: public OpConversionPattern<vector::MultiDimReductionOp> {
using OpConversionPattern<vector::MultiDimReductionOp>::OpConversionPattern;
LogicalResult
matchAndRewrite(vector::MultiDimReductionOp op, OneToNOpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
+ Location loc = op.getLoc();
+
VectorType srcType = op.getSourceVectorType();
VectorType dstType = dyn_cast<VectorType>(op.getResult().getType());
if (!dstType)
return failure();
- auto srcShape = srcType.getShape();
+ auto originalSrcShape = srcType.getShape();
xegpu::DistributeLayoutAttr layout =
xegpu::getDistributeLayoutAttr(op.getResult());
+
if (!layout || !layout.isForWorkgroup())
return failure();
auto reductionDims = llvm::to_vector(op.getReductionDims());
+ if (reductionDims.size() != 1)
+ return rewriter.notifyMatchFailure(
+ op, "Only single dimension reduction is supported");
+
+ // Get sg_layout and sg_data from the parent layout
+ SmallVector<int64_t> sgLayout;
+ SmallVector<int64_t> sgData;
+ if (auto sliceAttr = dyn_cast<xegpu::SliceAttr>(layout)) {
+ sgLayout = sliceAttr.getParent().getEffectiveSgLayoutAsInt();
+ sgData = sliceAttr.getParent().getEffectiveSgDataAsInt();
+ } else
+ return rewriter.notifyMatchFailure(
+ op, "Reduction should have SliceAttr layout");
+
+ Type elemTy = dstType.getElementType();
+
+ // Step 1: perform local subgroup reductions with ZERO accumulator
+ SmallVector<Value> localReductions;
+ auto sources = adaptor.getSource();
+ auto accs = adaptor.getAcc();
+
+ SmallVector<Value> expandedAccs;
+ if (accs.size() == 1 && sources.size() > 1) {
+ for (size_t i = 0; i < sources.size(); ++i)
+ expandedAccs.push_back(accs[0]);
+ } else
+ expandedAccs = llvm::to_vector(accs);
+
+ SmallVector<int64_t> sgShape =
+ getSgShapeAndCount(originalSrcShape, layout).first;
+ VectorType newDstType = VectorType::get({sgShape}, elemTy);
+ for (auto [sgSrc, sgAcc] : llvm::zip(sources, expandedAccs)) {
+ // Create ZERO accumulator for local reduction
+ auto zeroLocalAcc = arith::ConstantOp::create(
+ rewriter, loc, newDstType,
+ DenseElementsAttr::get(newDstType, rewriter.getZeroAttr(elemTy)));
+ // Local reduction with ZERO accumulator
+ auto localReduce = vector::MultiDimReductionOp::create(
+ rewriter, loc, newDstType, op.getKind(), sgSrc,
+ zeroLocalAcc.getResult(), reductionDims);
+ localReductions.push_back(localReduce.getResult());
+ }
- SmallVector<int64_t> sgLayout = llvm::cast<xegpu::SliceAttr>(layout)
- .getParent()
- .getEffectiveSgLayoutAsInt();
- SmallVector<int64_t> sgData = llvm::cast<xegpu::SliceAttr>(layout)
- .getParent()
- .getEffectiveSgDataAsInt();
-
- // Check that the sgLayout in the reduced dimension is 1 and
- // each sg gets the entire slice to reduce.
- for (int64_t dim : reductionDims) {
- if (sgLayout[dim] != 1 || sgData[dim] != srcShape[dim])
- return rewriter.notifyMatchFailure(
- op,
- "sgLayout in each reduced dimension must be 1 and sgData in the "
- "reduced dim must match srcShape in that dim");
+ // Check if cross-subgroup reduction is needed
+ int64_t reductionDim = reductionDims[0];
+ bool needsCrossSubgroupReduction = (sgLayout[reductionDim] > 1);
+
+ // If no cross-subgroup reduction needed, add accumulator and return
+ if (!needsCrossSubgroupReduction) {
+ SmallVector<Value> results;
+ for (auto localResult : localReductions) {
+ auto finalResult = arith::AddFOp::create(rewriter, loc, localResult,
+ adaptor.getAcc()[0]);
+ if (auto defOp = finalResult.getResult().getDefiningOp())
+ xegpu::setDistributeLayoutAttr(defOp->getResult(0),
+ layout.dropSgLayoutAndData());
+ results.push_back(finalResult.getResult());
+ }
+ rewriter.replaceOpWithMultiple(op, {results});
+ return success();
}
- SmallVector<int64_t> sgShape = getSgShapeAndCount(srcShape, layout).first;
+ // Step 2: Cross-subgroup reduction using SLM
- VectorType newDstType =
- VectorType::get({sgShape}, dstType.getElementType());
+ // Calculate total elements in local result
+ int64_t localElements = computeProduct(sgShape);
- SmallVector<Value> newReductions;
- for (auto sgSrc : adaptor.getSource()) {
- auto newOp = vector::MultiDimReductionOp::create(
- rewriter, op.getLoc(), newDstType, op.getKind(), sgSrc,
- adaptor.getAcc()[0], op.getReductionDims());
- xegpu::setDistributeLayoutAttr(newOp->getResult(0),
- layout.dropSgLayoutAndData());
- newReductions.push_back(newOp.getResult());
+ // Shape cast for SLM storage - store as [1, localElements]
+ SmallVector<int64_t> storeShape2D = {1, localElements};
+ VectorType storeType2D = VectorType::get(storeShape2D, elemTy);
+ auto storeShapeCast = vector::ShapeCastOp::create(
+ rewriter, loc, storeType2D, localReductions[0]);
+ Value storeData = storeShapeCast.getResult();
+
+ // Calculate SLM shape
+ int64_t totalReductionSubgroups =
+ sgLayout[static_cast<size_t>(reductionDims[0])];
+
+ // Total result elements across all subgroups in non-reduction dimensions
+ int64_t totalResultElements = localElements;
+ for (size_t i = 0; i < sgLayout.size(); ++i) {
+ if (!llvm::is_contained(reductionDims, static_cast<int64_t>(i)))
+ totalResultElements *= sgLayout[i];
+ }
+
+ SmallVector<int64_t> slmShape2D = {totalReductionSubgroups,
+ totalResultElements};
+
+ // Allocate SLM
+ auto bitWidth = elemTy.getIntOrFloatBitWidth();
+ auto bytesPerElement = bitWidth / 8;
+ int64_t slmElements = slmShape2D[0] * slmShape2D[1];
+ auto slmSize = slmElements * bytesPerElement;
+ auto slmTy = MemRefType::get({slmSize}, rewriter.getI8Type(), {}, 3);
+ auto slm = memref::AllocaOp::create(rewriter, loc, slmTy);
+
+ auto memDescType = xegpu::MemDescType::get(rewriter.getContext(),
+ slmShape2D, elemTy, nullptr);
+ auto memDesc =
+ xegpu::CreateMemDescOp::create(rewriter, loc, memDescType, slm);
+
+ // Step 4: Store local results to SLM
+ auto sgId = gpu::SubgroupIdOp::create(rewriter, loc,
+ rewriter.getIndexType(), nullptr);
+
+ // Convert sgLayout to Values for delinearizeIndex
+ SmallVector<Value> sgLayoutValues;
+ for (int64_t dim : sgLayout)
+ sgLayoutValues.push_back(
+ arith::ConstantIndexOp::create(rewriter, loc, dim));
+
+ auto sgIdsResult = affine::delinearizeIndex(rewriter, loc, sgId.getResult(),
+ sgLayoutValues);
+ if (failed(sgIdsResult))
+ return failure();
+ SmallVector<Value> sgIds = *sgIdsResult;
+
+ // Row offset is simply the subgroup ID along the reduction dimension
+ Value rowOffset = sgIds[reductionDim];
+
+ // Column offset: linearize all non-reduction dimensions and multiply by
+ // localElements
+ Value colOffset = arith::ConstantIndexOp::create(rewriter, loc, 0);
+ int64_t currentStride = 1;
+ for (size_t i = 0; i < sgLayout.size(); ++i) {
+ if (static_cast<int64_t>(i) != reductionDim) { // Skip reduction dimension
+ Value dimVal = sgIds[i];
+ Value strideVal =
+ arith::ConstantIndexOp::create(rewriter, loc, currentStride);
+ Value term = arith::MulIOp::create(rewriter, loc, dimVal, strideVal);
+ colOffset = arith::AddIOp::create(rewriter, loc, colOffset, term);
+ currentStride *= sgLayout[i];
+ }
}
+ Value localElementsVal =
+ arith::ConstantIndexOp::create(rewriter, loc, localElements);
+ colOffset =
+ arith::MulIOp::create(rewriter, loc, colOffset, localElementsVal);
+
+ SmallVector<OpFoldResult> storeOffsets2D = {rowOffset, colOffset};
+
+ xegpu::StoreMatrixOp::create(rewriter, loc, storeData, memDesc.getResult(),
+ storeOffsets2D, /*layout=*/nullptr);
+
+ gpu::BarrierOp::create(rewriter, loc);
----------------
Jianhui-Li wrote:
To synchronize data between the producer and consumer subgroups (sg), both a barrier and a memory fence are needed.
https://github.com/llvm/llvm-project/pull/170936
More information about the Mlir-commits
mailing list