[Mlir-commits] [mlir] [MLIR][XeGPU] Add support for cross-subgroup reduction from wg to sg (PR #170936)

Jianhui Li llvmlistbot at llvm.org
Tue Dec 9 11:04:56 PST 2025


================
@@ -1152,64 +1152,232 @@ struct WgToSgVectorShapeCastOp
   }
 };
 
-/// Pattern for lowering vector.multi_reduction op to subgroup level.
-/// Current limitation: the sg_layout in the reduced dimension being 1
-/// so that reduction is local to subgroup & no cross-subgroup communication is
-/// needed.
-/// TODO: Add cases to handle more general situations which require SLM access.
+// This pattern transforms vector.multi_dim_reduction ops to work at subgroup
+// level.
 struct WgToSgMultiDimReductionOp
     : public OpConversionPattern<vector::MultiDimReductionOp> {
   using OpConversionPattern<vector::MultiDimReductionOp>::OpConversionPattern;
 
   LogicalResult
   matchAndRewrite(vector::MultiDimReductionOp op, OneToNOpAdaptor adaptor,
                   ConversionPatternRewriter &rewriter) const override {
+    Location loc = op.getLoc();
+
     VectorType srcType = op.getSourceVectorType();
     VectorType dstType = dyn_cast<VectorType>(op.getResult().getType());
     if (!dstType)
       return failure();
 
-    auto srcShape = srcType.getShape();
+    auto originalSrcShape = srcType.getShape();
     xegpu::DistributeLayoutAttr layout =
         xegpu::getDistributeLayoutAttr(op.getResult());
+
     if (!layout || !layout.isForWorkgroup())
       return failure();
 
     auto reductionDims = llvm::to_vector(op.getReductionDims());
+    if (reductionDims.size() != 1)
+      return rewriter.notifyMatchFailure(
+          op, "Only single dimension reduction is supported");
+
+    // Get sg_layout and sg_data from the parent layout
+    SmallVector<int64_t> sgLayout;
+    SmallVector<int64_t> sgData;
+    if (auto sliceAttr = dyn_cast<xegpu::SliceAttr>(layout)) {
+      sgLayout = sliceAttr.getParent().getEffectiveSgLayoutAsInt();
+      sgData = sliceAttr.getParent().getEffectiveSgDataAsInt();
+    } else
+      return rewriter.notifyMatchFailure(
+          op, "Reduction should have SliceAttr layout");
+
+    Type elemTy = dstType.getElementType();
+
+    // Step 1: perform local subgroup reductions with ZERO accumulator
+    SmallVector<Value> localReductions;
+    auto sources = adaptor.getSource();
+    auto accs = adaptor.getAcc();
+
+    SmallVector<Value> expandedAccs;
+    if (accs.size() == 1 && sources.size() > 1) {
+      for (size_t i = 0; i < sources.size(); ++i)
+        expandedAccs.push_back(accs[0]);
+    } else
+      expandedAccs = llvm::to_vector(accs);
+
+    SmallVector<int64_t> sgShape =
+        getSgShapeAndCount(originalSrcShape, layout).first;
+    VectorType newDstType = VectorType::get({sgShape}, elemTy);
+    for (auto [sgSrc, sgAcc] : llvm::zip(sources, expandedAccs)) {
+      // Create ZERO accumulator for local reduction
+      auto zeroLocalAcc = arith::ConstantOp::create(
+          rewriter, loc, newDstType,
+          DenseElementsAttr::get(newDstType, rewriter.getZeroAttr(elemTy)));
+      // Local reduction with ZERO accumulator
+      auto localReduce = vector::MultiDimReductionOp::create(
+          rewriter, loc, newDstType, op.getKind(), sgSrc,
+          zeroLocalAcc.getResult(), reductionDims);
+      localReductions.push_back(localReduce.getResult());
+    }
 
-    SmallVector<int64_t> sgLayout = llvm::cast<xegpu::SliceAttr>(layout)
-                                        .getParent()
-                                        .getEffectiveSgLayoutAsInt();
-    SmallVector<int64_t> sgData = llvm::cast<xegpu::SliceAttr>(layout)
-                                      .getParent()
-                                      .getEffectiveSgDataAsInt();
-
-    // Check that the sgLayout in the reduced dimension is 1 and
-    // each sg gets the entire slice to reduce.
-    for (int64_t dim : reductionDims) {
-      if (sgLayout[dim] != 1 || sgData[dim] != srcShape[dim])
-        return rewriter.notifyMatchFailure(
-            op,
-            "sgLayout in each reduced dimension must be 1 and sgData in the "
-            "reduced dim must match srcShape in that dim");
+    // Check if cross-subgroup reduction is needed
+    int64_t reductionDim = reductionDims[0];
+    bool needsCrossSubgroupReduction = (sgLayout[reductionDim] > 1);
+
+    // If no cross-subgroup reduction needed, add accumulator and return
----------------
Jianhui-Li wrote:

The code could use some helper functions so that the main function becomes shorter. 

https://github.com/llvm/llvm-project/pull/170936


More information about the Mlir-commits mailing list