[Mlir-commits] [mlir] [MLIR][XeGPU] Add support for cross-subgroup reduction from wg to sg (PR #170936)

Jianhui Li llvmlistbot at llvm.org
Fri Dec 19 22:17:52 PST 2025


================
@@ -1161,64 +1161,338 @@ struct WgToSgVectorShapeCastOp
   }
 };
 
-/// Pattern for lowering vector.multi_reduction op to subgroup level.
-/// Current limitation: the sg_layout in the reduced dimension being 1
-/// so that reduction is local to subgroup & no cross-subgroup communication is
-/// needed.
-/// TODO: Add cases to handle more general situations which require SLM access.
+/// Converts multi-dimensional subgroup indices into a single linear offset.
+/// It is used to compute the row and column offsets into the SLM buffer that
+/// coordinates a cross-subgroup reduction.
+///
+/// Parameters:
+/// - sgIds: Multi-dimensional subgroup indices (e.g., [sgId_x, sgId_y, sgId_z])
+/// - dims: Which dimensions to include in the linearization, listed from the
+///   fastest-varying to the slowest-varying (e.g., [0, 2] for the x and z
+///   dims)
+/// - sgLayout: Subgroup layout sizes for each dimension (e.g., [4, 8, 2] means
+///   4x8x2 subgroups)
+///
+/// Only the dimensions listed in `dims` contribute, and the FIRST listed
+/// dimension varies fastest. Note that this is NOT row-major order over
+/// `sgLayout`:
+///    offset   = sum_i(sgIds[dims[i]] * stride_i)
+///    stride_0 = 1
+///    stride_i = stride_{i-1} * sgLayout[dims[i-1]]
+///
+/// Example:
+/// - sgLayout = [4, 8, 2], dims = [0, 2] (linearize x and z dimensions)
+/// - sgIds = [1, 3, 1] (subgroup at position x=1, y=3, z=1)
+/// - Calculation:
+///   * dim=0: stride=1, term = sgIds[0] * 1 = 1 * 1 = 1
+///   * dim=2: stride=sgLayout[0]=4, term = sgIds[2] * 4 = 1 * 4 = 4
+///   * linearizedOffset = 1 + 4 = 5
+///
+/// Assuming each sgIds[d] lies in [0, sgLayout[d]), this yields a unique
+/// index in [0, product of sgLayout[d] over d in dims) for each combination
+/// of subgroup positions in the listed dimensions. An empty `dims` yields the
+/// constant 0.
+static Value linearizeSubgroupIndices(ConversionPatternRewriter &rewriter,
+                                      Location loc, ArrayRef<Value> sgIds,
+                                      ArrayRef<int64_t> dims,
+                                      ArrayRef<int64_t> sgLayout) {
+  // Seed with a constant 0 so that an empty `dims` list is well defined.
+  Value linearizedOffset = arith::ConstantIndexOp::create(rewriter, loc, 0);
+  // Compile-time stride of the current dim: the product of the extents of
+  // all dims that have already been linearized.
+  int64_t stride = 1;
+
+  for (int64_t dim : dims) {
+    assert(dim >= 0 && static_cast<size_t>(dim) < sgIds.size() &&
+           static_cast<size_t>(dim) < sgLayout.size() &&
+           "linearization dim out of range");
+    Value dimVal = sgIds[dim];
+    Value strideVal = arith::ConstantIndexOp::create(rewriter, loc, stride);
+    Value term = arith::MulIOp::create(rewriter, loc, dimVal, strideVal);
+    linearizedOffset =
+        arith::AddIOp::create(rewriter, loc, linearizedOffset, term);
+    stride *= sgLayout[dim];
+  }
+
+  return linearizedOffset;
+}
+
+/// Combines `lhs` and `rhs` with the arith operation that corresponds to the
+/// reduction `kind` and returns the combined value. For example, ADD becomes
+/// arith.addf for float element types and arith.addi otherwise. Operands may
+/// be scalars or vectors and must have the same type, as the arith binary ops
+/// require.
+///
+/// The CombiningKind -> arith op mapping is delegated to
+/// vector::makeArithReduction rather than duplicated here, so it stays in
+/// sync with the vector dialect's own definition of each combining kind.
+/// NOTE(review): makeArithReduction may fold the combined op, so callers
+/// should not assume the result always has a freshly created defining op.
+static Value reductionOpKind(ConversionPatternRewriter &rewriter, Location loc,
+                             vector::CombiningKind kind, Value lhs, Value rhs) {
+  return vector::makeArithReduction(rewriter, loc, kind, lhs, rhs);
+}
+
+/// This pattern transforms vector.multi_dim_reduction operations from
+/// workgroup-level to subgroup-level execution with support for multiple
+/// reduction dimensions.
+///
+/// Steps include:
+/// 1. LOCAL REDUCTION :
+///    - Each subgroup performs local reduction on its data slice
+///    - Uses ZERO accumulator to avoid double-counting during cross-subgroup
+///    phase
+///
+/// 2. CROSS-SUBGROUP :
+///    - Determines if cross-subgroup reduction is needed (when sg_layout > 1 in
+///    reduction dims)
+///    - If not needed, adds original accumulator and returns local results
+///
+/// 3. SHARED LOCAL MEMORY (SLM) PHASE (when cross-subgroup reduction needed):
+///    a) SLM Layout Design:
+///       - Rows: subgroups participating in reduction (product of sg_layout in
+///       reduction dims)
+///       - Cols: total result elements across non-reduction dimensions
+///
+///    b) Store Phase:
+///       - Each subgroup stores its local reduction result to SLM
+///       - Row offset: linearized index of subgroup in reduction dimensions
+///       - Col offset: linearized index of subgroup in non-reduction dimensions
+///
+///    c) Load and Final Reduction Phase:
+///       - Each subgroup loads a column of data (all reduction participants for
+///       its position)
+///       - Performs final reduction along the loaded dimension
+///       - Adds original accumulator to get final result
+///
 struct WgToSgMultiDimReductionOp
     : public OpConversionPattern<vector::MultiDimReductionOp> {
   using OpConversionPattern<vector::MultiDimReductionOp>::OpConversionPattern;
 
   LogicalResult
   matchAndRewrite(vector::MultiDimReductionOp op, OneToNOpAdaptor adaptor,
                   ConversionPatternRewriter &rewriter) const override {
+    Location loc = op.getLoc();
+
     VectorType srcType = op.getSourceVectorType();
     VectorType dstType = dyn_cast<VectorType>(op.getResult().getType());
     if (!dstType)
       return failure();
 
-    auto srcShape = srcType.getShape();
+    auto originalSrcShape = srcType.getShape();
     xegpu::DistributeLayoutAttr layout =
         xegpu::getDistributeLayoutAttr(op.getResult());
+
     if (!layout || !layout.isForWorkgroup())
       return failure();
 
     auto reductionDims = llvm::to_vector(op.getReductionDims());
 
-    SmallVector<int64_t> sgLayout = llvm::cast<xegpu::SliceAttr>(layout)
-                                        .getParent()
-                                        .getEffectiveSgLayoutAsInt();
-    SmallVector<int64_t> sgData = llvm::cast<xegpu::SliceAttr>(layout)
-                                      .getParent()
-                                      .getEffectiveSgDataAsInt();
-
-    // Check that the sgLayout in the reduced dimension is 1 and
-    // each sg gets the entire slice to reduce.
-    for (int64_t dim : reductionDims) {
-      if (sgLayout[dim] != 1 || sgData[dim] != srcShape[dim])
-        return rewriter.notifyMatchFailure(
-            op,
-            "sgLayout in each reduced dimension must be 1 and sgData in the "
-            "reduced dim must match srcShape in that dim");
+    // Get sg_layout and sg_data from the parent layout
+    SmallVector<int64_t> sgLayout;
+    SmallVector<int64_t> sgData;
+    if (auto sliceAttr = dyn_cast<xegpu::SliceAttr>(layout)) {
+      sgLayout = sliceAttr.getParent().getEffectiveSgLayoutAsInt();
+      sgData = sliceAttr.getParent().getEffectiveSgDataAsInt();
+    } else
+      return rewriter.notifyMatchFailure(
+          op, "Reduction should have SliceAttr layout");
+
+    Type elemTy = dstType.getElementType();
+
+    // Step 1: perform local subgroup reductions with ZERO accumulator
+    SmallVector<Value> localReductions;
+    SmallVector<int64_t> sgShape =
+        getSgShapeAndCount(originalSrcShape, layout).first;
+    VectorType newDstType = VectorType::get(sgShape, elemTy);
+    for (auto sgSrc : adaptor.getSource()) {
+      // Create ZERO accumulator for local reduction
+      auto zeroLocalAcc = arith::ConstantOp::create(
+          rewriter, loc, newDstType,
+          DenseElementsAttr::get(newDstType, rewriter.getZeroAttr(elemTy)));
+      // Local reduction with ZERO accumulator
+      auto localReduce = vector::MultiDimReductionOp::create(
+          rewriter, loc, newDstType, op.getKind(), sgSrc,
+          zeroLocalAcc.getResult(), reductionDims);
+      localReductions.push_back(localReduce.getResult());
     }
 
-    SmallVector<int64_t> sgShape = getSgShapeAndCount(srcShape, layout).first;
+    // Check if cross-subgroup reduction is needed for any reduction dimension
+    bool needsCrossSubgroupReduction = false;
+    SmallVector<int64_t> crossSgReductionDims;
+    for (int64_t reductionDim : reductionDims) {
+      if (sgLayout[reductionDim] > 1) {
+        needsCrossSubgroupReduction = true;
+        crossSgReductionDims.push_back(reductionDim);
+      }
+    }
 
-    VectorType newDstType =
-        VectorType::get({sgShape}, dstType.getElementType());
+    // If no cross-subgroup reduction needed, add accumulator and return
+    if (!needsCrossSubgroupReduction) {
+      SmallVector<Value> results;
+      for (auto localResult : localReductions) {
+        auto finalResult = reductionOpKind(rewriter, loc, op.getKind(),
+                                           localResult, adaptor.getAcc()[0]);
+        if (auto defOp = finalResult.getDefiningOp())
+          xegpu::setDistributeLayoutAttr(defOp->getResult(0),
+                                         layout.dropSgLayoutAndData());
+        results.push_back(finalResult);
+      }
+      rewriter.replaceOpWithMultiple(op, {results});
+      return success();
+    }
 
-    SmallVector<Value> newReductions;
-    for (auto sgSrc : adaptor.getSource()) {
-      auto newOp = vector::MultiDimReductionOp::create(
-          rewriter, op.getLoc(), newDstType, op.getKind(), sgSrc,
-          adaptor.getAcc()[0], op.getReductionDims());
-      xegpu::setDistributeLayoutAttr(newOp->getResult(0),
-                                     layout.dropSgLayoutAndData());
-      newReductions.push_back(newOp.getResult());
+    // Step 2: cross-subgroup reduction using SLM
+
+    // Calculate total elements in local result
+    int64_t localElements = computeProduct(sgShape);
+
+    // Shape cast for SLM storage - store as [1, localElements]
+    SmallVector<int64_t> storeShape2D = {1, localElements};
+    VectorType storeType2D = VectorType::get(storeShape2D, elemTy);
+    auto storeShapeCast = vector::ShapeCastOp::create(
+        rewriter, loc, storeType2D, localReductions[0]);
+    Value storeData = storeShapeCast.getResult();
+
+    // Calculate SLM shape - rows for sg's in reduction dims, cols for total
+    // result elements across all subgroups in non-reduction dimensions
+    int64_t totalReductionSubgroups = 1;
+    for (int64_t dim : crossSgReductionDims) {
+      totalReductionSubgroups *= sgLayout[dim];
+    }
+
+    // Total result elements across all subgroups in non-reduction dimensions
+    int64_t totalResultElements =
+        localElements * computeProduct(sgLayout) / totalReductionSubgroups;
+
+    SmallVector<int64_t> slmShape2D = {totalReductionSubgroups,
+                                       totalResultElements};
+
+    // Allocate SLM
+    auto bitWidth = elemTy.getIntOrFloatBitWidth();
+    auto bytesPerElement = bitWidth / 8;
+    int64_t slmElements = slmShape2D[0] * slmShape2D[1];
+    auto slmSize = slmElements * bytesPerElement;
+    auto slmTy = MemRefType::get({slmSize}, rewriter.getI8Type(), {}, 3);
+    auto slm = memref::AllocaOp::create(rewriter, loc, slmTy);
+
+    auto memDescType = xegpu::MemDescType::get(rewriter.getContext(),
+                                               slmShape2D, elemTy, nullptr);
+    auto memDesc =
+        xegpu::CreateMemDescOp::create(rewriter, loc, memDescType, slm);
+
+    // Step 4: Store local results to SLM
+    auto sgId = gpu::SubgroupIdOp::create(rewriter, loc,
+                                          rewriter.getIndexType(), nullptr);
+
+    // Convert sgLayout to Values for delinearizeIndex
+    SmallVector<Value> sgLayoutValues;
+    for (int64_t dim : sgLayout)
+      sgLayoutValues.push_back(
+          arith::ConstantIndexOp::create(rewriter, loc, dim));
+
+    auto sgIdsResult = affine::delinearizeIndex(rewriter, loc, sgId.getResult(),
+                                                sgLayoutValues);
+    if (failed(sgIdsResult))
+      return failure();
+    SmallVector<Value> sgIds = *sgIdsResult;
+
+    // Row offset: linearize reduction dimension indices
+    Value rowOffsetStore = linearizeSubgroupIndices(
+        rewriter, loc, sgIds, crossSgReductionDims, sgLayout);
+
+    // Column offset: linearize non-reduction dimension indices
+    SmallVector<int64_t> nonReductionDims;
+    for (size_t i = 0; i < sgLayout.size(); ++i) {
+      if (!llvm::is_contained(reductionDims, static_cast<int64_t>(i))) {
+        nonReductionDims.push_back(static_cast<int64_t>(i));
+      }
+    }
+
+    Value colOffset = linearizeSubgroupIndices(rewriter, loc, sgIds,
+                                               nonReductionDims, sgLayout);
+
+    Value localElementsVal =
+        arith::ConstantIndexOp::create(rewriter, loc, localElements);
+    colOffset =
+        arith::MulIOp::create(rewriter, loc, colOffset, localElementsVal);
+
+    SmallVector<OpFoldResult> storeOffsets2D = {rowOffsetStore, colOffset};
+
+    xegpu::StoreMatrixOp::create(rewriter, loc, storeData, memDesc.getResult(),
----------------
Jianhui-Li wrote:

storeMatrix is an anchor op and serves as the starting point of backward layout propagation, so it must set a layout. 

https://github.com/llvm/llvm-project/pull/170936


More information about the Mlir-commits mailing list