[Mlir-commits] [mlir] [MLIR][XeGPU] Add support for cross-subgroup reduction from wg to sg (PR #170936)
Nishant Patel
llvmlistbot at llvm.org
Tue Jan 13 11:40:44 PST 2026
https://github.com/nbpatel updated https://github.com/llvm/llvm-project/pull/170936
>From 91de10671a4eb3d2ac8876bd45e347d2d62d7015 Mon Sep 17 00:00:00 2001
From: nbpatel <nishant.b.patel at intel.com>
Date: Thu, 4 Dec 2025 00:08:26 +0000
Subject: [PATCH 1/6] Add support for cross-subgroup reduction from wg to sg
---
.../Transforms/XeGPUWgToSgDistribute.cpp | 232 +++++++++++++++---
.../XeGPU/xegpu-wg-to-sg-unify-ops-rr.mlir | 4 +-
.../XeGPU/xegpu-wg-to-sg-unify-ops.mlir | 85 +++++++
3 files changed, 287 insertions(+), 34 deletions(-)
diff --git a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUWgToSgDistribute.cpp b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUWgToSgDistribute.cpp
index 48bd0662b03ff..173a798fa9021 100644
--- a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUWgToSgDistribute.cpp
+++ b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUWgToSgDistribute.cpp
@@ -1152,11 +1152,8 @@ struct WgToSgVectorShapeCastOp
}
};
-/// Pattern for lowering vector.multi_reduction op to subgroup level.
-/// Current limitation: the sg_layout in the reduced dimension being 1
-/// so that reduction is local to subgroup & no cross-subgroup communication is
-/// needed.
-/// TODO: Add cases to handle more general situations which require SLM access.
+// This pattern transforms vector.multi_dim_reduction ops to work at subgroup
+// level.
struct WgToSgMultiDimReductionOp
: public OpConversionPattern<vector::MultiDimReductionOp> {
using OpConversionPattern<vector::MultiDimReductionOp>::OpConversionPattern;
@@ -1164,52 +1161,221 @@ struct WgToSgMultiDimReductionOp
LogicalResult
matchAndRewrite(vector::MultiDimReductionOp op, OneToNOpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
+ Location loc = op.getLoc();
+
VectorType srcType = op.getSourceVectorType();
VectorType dstType = dyn_cast<VectorType>(op.getResult().getType());
if (!dstType)
return failure();
- auto srcShape = srcType.getShape();
+ auto originalSrcShape = srcType.getShape();
xegpu::DistributeLayoutAttr layout =
xegpu::getDistributeLayoutAttr(op.getResult());
+
if (!layout || !layout.isForWorkgroup())
return failure();
auto reductionDims = llvm::to_vector(op.getReductionDims());
+ if (reductionDims.size() != 1)
+ return rewriter.notifyMatchFailure(
+ op, "Only single dimension reduction is supported");
+
+ // Get sg_layout and sg_data from the parent layout
+ SmallVector<int64_t> sgLayout;
+ SmallVector<int64_t> sgData;
+ if (auto sliceAttr = dyn_cast<xegpu::SliceAttr>(layout)) {
+ sgLayout = sliceAttr.getParent().getEffectiveSgLayoutAsInt();
+ sgData = sliceAttr.getParent().getEffectiveSgDataAsInt();
+ } else
+ return rewriter.notifyMatchFailure(
+ op, "Reduction should have SliceAttr layout");
+
+ Type elemTy = dstType.getElementType();
+
+ // Step 1: perform local subgroup reductions with ZERO accumulator
+ SmallVector<Value> localReductions;
+ auto sources = adaptor.getSource();
+ auto accs = adaptor.getAcc();
+
+ SmallVector<Value> expandedAccs;
+ if (accs.size() == 1 && sources.size() > 1) {
+ for (size_t i = 0; i < sources.size(); ++i)
+ expandedAccs.push_back(accs[0]);
+ } else
+ expandedAccs = llvm::to_vector(accs);
+
+ SmallVector<int64_t> sgShape =
+ getSgShapeAndCount(originalSrcShape, layout).first;
+ VectorType newDstType = VectorType::get({sgShape}, elemTy);
+ for (auto [sgSrc, sgAcc] : llvm::zip(sources, expandedAccs)) {
+ // Create ZERO accumulator for local reduction
+ auto zeroLocalAcc = arith::ConstantOp::create(
+ rewriter, loc, newDstType,
+ DenseElementsAttr::get(newDstType, rewriter.getZeroAttr(elemTy)));
+ // Local reduction with ZERO accumulator
+ auto localReduce = vector::MultiDimReductionOp::create(
+ rewriter, loc, newDstType, op.getKind(), sgSrc,
+ zeroLocalAcc.getResult(), reductionDims);
+ localReductions.push_back(localReduce.getResult());
+ }
- SmallVector<int64_t> sgLayout = llvm::cast<xegpu::SliceAttr>(layout)
- .getParent()
- .getEffectiveSgLayoutAsInt();
- SmallVector<int64_t> sgData = llvm::cast<xegpu::SliceAttr>(layout)
- .getParent()
- .getEffectiveSgDataAsInt();
-
- // Check that the sgLayout in the reduced dimension is 1 and
- // each sg gets the entire slice to reduce.
- for (int64_t dim : reductionDims) {
- if (sgLayout[dim] != 1 || sgData[dim] != srcShape[dim])
- return rewriter.notifyMatchFailure(
- op,
- "sgLayout in each reduced dimension must be 1 and sgData in the "
- "reduced dim must match srcShape in that dim");
+ // Check if cross-subgroup reduction is needed
+ int64_t reductionDim = reductionDims[0];
+ bool needsCrossSubgroupReduction = (sgLayout[reductionDim] > 1);
+
+ // If no cross-subgroup reduction needed, add accumulator and return
+ if (!needsCrossSubgroupReduction) {
+ SmallVector<Value> results;
+ for (auto localResult : localReductions) {
+ auto finalResult = arith::AddFOp::create(rewriter, loc, localResult,
+ adaptor.getAcc()[0]);
+ if (auto defOp = finalResult.getResult().getDefiningOp())
+ xegpu::setDistributeLayoutAttr(defOp->getResult(0),
+ layout.dropSgLayoutAndData());
+ results.push_back(finalResult.getResult());
+ }
+ rewriter.replaceOpWithMultiple(op, {results});
+ return success();
}
- SmallVector<int64_t> sgShape = getSgShapeAndCount(srcShape, layout).first;
+ // Step 2: Cross-subgroup reduction using SLM
- VectorType newDstType =
- VectorType::get({sgShape}, dstType.getElementType());
+ // Calculate total elements in local result
+ int64_t localElements = computeProduct(sgShape);
- SmallVector<Value> newReductions;
- for (auto sgSrc : adaptor.getSource()) {
- auto newOp = vector::MultiDimReductionOp::create(
- rewriter, op.getLoc(), newDstType, op.getKind(), sgSrc,
- adaptor.getAcc()[0], op.getReductionDims());
- xegpu::setDistributeLayoutAttr(newOp->getResult(0),
- layout.dropSgLayoutAndData());
- newReductions.push_back(newOp.getResult());
+ // Shape cast for SLM storage - store as [1, localElements]
+ SmallVector<int64_t> storeShape2D = {1, localElements};
+ VectorType storeType2D = VectorType::get(storeShape2D, elemTy);
+ auto storeShapeCast = vector::ShapeCastOp::create(
+ rewriter, loc, storeType2D, localReductions[0]);
+ Value storeData = storeShapeCast.getResult();
+
+ // Calculate SLM shape
+ int64_t totalReductionSubgroups =
+ sgLayout[static_cast<size_t>(reductionDims[0])];
+
+ // Total result elements across all subgroups in non-reduction dimensions
+ int64_t totalResultElements = localElements;
+ for (size_t i = 0; i < sgLayout.size(); ++i) {
+ if (!llvm::is_contained(reductionDims, static_cast<int64_t>(i)))
+ totalResultElements *= sgLayout[i];
+ }
+
+ SmallVector<int64_t> slmShape2D = {totalReductionSubgroups,
+ totalResultElements};
+
+ // Allocate SLM
+ auto bitWidth = elemTy.getIntOrFloatBitWidth();
+ auto bytesPerElement = bitWidth / 8;
+ int64_t slmElements = slmShape2D[0] * slmShape2D[1];
+ auto slmSize = slmElements * bytesPerElement;
+ auto slmTy = MemRefType::get({slmSize}, rewriter.getI8Type(), {}, 3);
+ auto slm = memref::AllocaOp::create(rewriter, loc, slmTy);
+
+ auto memDescType = xegpu::MemDescType::get(rewriter.getContext(),
+ slmShape2D, elemTy, nullptr);
+ auto memDesc =
+ xegpu::CreateMemDescOp::create(rewriter, loc, memDescType, slm);
+
+ // Step 4: Store local results to SLM
+ auto sgId = gpu::SubgroupIdOp::create(rewriter, loc,
+ rewriter.getIndexType(), nullptr);
+
+ // Convert sgLayout to Values for delinearizeIndex
+ SmallVector<Value> sgLayoutValues;
+ for (int64_t dim : sgLayout)
+ sgLayoutValues.push_back(
+ arith::ConstantIndexOp::create(rewriter, loc, dim));
+
+ auto sgIdsResult = affine::delinearizeIndex(rewriter, loc, sgId.getResult(),
+ sgLayoutValues);
+ if (failed(sgIdsResult))
+ return failure();
+ SmallVector<Value> sgIds = *sgIdsResult;
+
+ // Row offset is simply the subgroup ID along the reduction dimension
+ Value rowOffset = sgIds[reductionDim];
+
+ // Column offset: linearize all non-reduction dimensions and multiply by
+ // localElements
+ Value colOffset = arith::ConstantIndexOp::create(rewriter, loc, 0);
+ int64_t currentStride = 1;
+ for (size_t i = 0; i < sgLayout.size(); ++i) {
+ if (static_cast<int64_t>(i) != reductionDim) { // Skip reduction dimension
+ Value dimVal = sgIds[i];
+ Value strideVal =
+ arith::ConstantIndexOp::create(rewriter, loc, currentStride);
+ Value term = arith::MulIOp::create(rewriter, loc, dimVal, strideVal);
+ colOffset = arith::AddIOp::create(rewriter, loc, colOffset, term);
+ currentStride *= sgLayout[i];
+ }
+ }
+ Value localElementsVal =
+ arith::ConstantIndexOp::create(rewriter, loc, localElements);
+ colOffset =
+ arith::MulIOp::create(rewriter, loc, colOffset, localElementsVal);
+
+ SmallVector<OpFoldResult> storeOffsets2D = {rowOffset, colOffset};
+
+ xegpu::StoreMatrixOp::create(rewriter, loc, storeData, memDesc.getResult(),
+ storeOffsets2D, /*layout=*/nullptr);
+
+ gpu::BarrierOp::create(rewriter, loc);
+
+ // Step 5: Load from SLM for final reduction
+ SmallVector<int64_t> loadShape2D = {totalReductionSubgroups, localElements};
+ VectorType loadType2D = VectorType::get(loadShape2D, elemTy);
+
+ // Load offsets - each subgroup loads its column based on non-reduction
+ // position
+ Value loadOffsetY = arith::ConstantIndexOp::create(rewriter, loc, 0);
+ Value loadOffsetX = colOffset;
+
+ SmallVector<OpFoldResult> loadOffsets2D = {loadOffsetY, loadOffsetX};
+
+ auto loadOp = xegpu::LoadMatrixOp::create(
+ rewriter, loc, loadType2D, memDesc.getResult(), loadOffsets2D,
+ /*layout=*/nullptr);
+
+ // Step 6: Perform final reduction with ZERO accumulator
+ SmallVector<int64_t> finalReductionDims = {0};
+ SmallVector<int64_t> finalResultShape = {localElements};
+ VectorType finalResultType = VectorType::get(finalResultShape, elemTy);
+
+ // Create ZERO accumulator for final reduction
+ auto zeroFinalAcc = arith::ConstantOp::create(
+ rewriter, loc, finalResultType,
+ DenseElementsAttr::get(finalResultType, rewriter.getZeroAttr(elemTy)));
+
+ auto finalReduce = vector::MultiDimReductionOp::create(
+ rewriter, loc, finalResultType, op.getKind(), loadOp.getResult(),
+ zeroFinalAcc.getResult(), finalReductionDims);
+
+ // Step 7: Add the original accumulator at the end
+ Value originalAcc = adaptor.getAcc()[0];
+ Value accToAdd = originalAcc;
+
+ // Handle shape mismatch by shape casting
+ if (originalAcc.getType() != finalReduce.getResult().getType()) {
+ auto originalAccType = cast<VectorType>(originalAcc.getType());
+ auto finalResultType =
+ cast<VectorType>(finalReduce.getResult().getType());
+
+ // If they have the same number of elements, just shape cast
+ if (originalAccType.getNumElements() == finalResultType.getNumElements())
+ auto shapeCast = vector::ShapeCastOp::create(
+ rewriter, loc, finalResultType, originalAcc);
+ accToAdd = shapeCast.getResult();
}
- rewriter.replaceOpWithMultiple(op, {newReductions});
+ auto finalResult =
+ arith::AddFOp::create(rewriter, loc, finalReduce.getResult(), accToAdd);
+
+ if (auto defOp = finalResult.getResult().getDefiningOp())
+ xegpu::setDistributeLayoutAttr(defOp->getResult(0),
+ layout.dropSgLayoutAndData());
+
+ rewriter.replaceOp(op, finalResult.getResult());
return success();
}
};
diff --git a/mlir/test/Dialect/XeGPU/xegpu-wg-to-sg-unify-ops-rr.mlir b/mlir/test/Dialect/XeGPU/xegpu-wg-to-sg-unify-ops-rr.mlir
index 1cddccb5fbbd1..ff792d809a090 100644
--- a/mlir/test/Dialect/XeGPU/xegpu-wg-to-sg-unify-ops-rr.mlir
+++ b/mlir/test/Dialect/XeGPU/xegpu-wg-to-sg-unify-ops-rr.mlir
@@ -83,8 +83,10 @@ gpu.module @test_distribution {
%load = xegpu.load_nd %tdesc[0, 0]
: !xegpu.tensor_desc<256x64xf32, #xegpu.layout<sg_layout = [8, 1], sg_data = [16, 64]>>
-> vector<256x64xf32>
- // CHECK-COUNT-2: vector.multi_reduction <add>, {{.*}}, %[[CST]] [1] : vector<16x64xf32> to vector<16xf32>
+ // CHECK-COUNT-2: vector.multi_reduction <add>, {{.*}}, %[[C0:.*]] [1] : vector<16x64xf32> to vector<16xf32>
// CHECK-NOT: vector.multi_reduction
+ // CHECK-COUNT-2: arith.addf {{.*}}, {{.*}} : vector<16xf32>
+ // CHECK-NOT: arith.addf
%reduce = vector.multi_reduction <add>, %load, %cst {layout_result_0 = #xegpu.slice<#xegpu.layout<sg_layout = [8, 1], sg_data = [16, 64]>, dims = [1]>} [1]
: vector<256x64xf32> to vector<256xf32>
gpu.return
diff --git a/mlir/test/Dialect/XeGPU/xegpu-wg-to-sg-unify-ops.mlir b/mlir/test/Dialect/XeGPU/xegpu-wg-to-sg-unify-ops.mlir
index 574b365443a0a..bbacc527984e7 100644
--- a/mlir/test/Dialect/XeGPU/xegpu-wg-to-sg-unify-ops.mlir
+++ b/mlir/test/Dialect/XeGPU/xegpu-wg-to-sg-unify-ops.mlir
@@ -1,5 +1,10 @@
// RUN: mlir-opt --xegpu-wg-to-sg-distribute -split-input-file %s | FileCheck %s
+// CHECK-DAG: #map = affine_map<()[s0] -> (s0 floordiv 32)>
+// CHECK-DAG: #map1 = affine_map<()[s0] -> (s0 mod 32)>
+// CHECK-DAG: #map2 = affine_map<()[s0] -> (0)>
+// CHECK-DAG: #map3 = affine_map<()[s0] -> (s0 floordiv 4)>
+// CHECK-DAG: #map4 = affine_map<()[s0] -> (s0 mod 4)>
gpu.module @test_distribution {
// CHECK-LABEL: create_nd_tdesc_no_offset
// CHECK-SAME: %[[ARG_0:.*]]: memref<256x128xf32>
@@ -599,4 +604,84 @@ gpu.module @test_distribution {
#xegpu.layout<sg_layout = [8, 8], sg_data = [32, 32], inst_data = [8, 16]>} : vector<256xf32> to vector<256x256xf32>
gpu.return
}
+
+ // CHECK-LABEL: gpu.func @vector_reduce_cross_sg_dim_1
+// CHECK-SAME: (%[[ARG0:.*]]: memref<?xf32>)
+gpu.func @vector_reduce_cross_sg_dim_1(%src: memref<?xf32>) {
+ // CHECK-DAG: %[[CST:.*]] = arith.constant dense<1.000000e+00> : vector<1x32xf32>
+ // CHECK-DAG: %[[CST_0:.*]] = arith.constant dense<0> : vector<1x1x32xindex>
+ // CHECK-DAG: %[[CST_1:.*]] = arith.constant dense<true> : vector<1x1x32xi1>
+ // CHECK-DAG: %[[LOAD:.*]] = xegpu.load %[[ARG0]][%[[CST_0]]], %[[CST_1]] <{chunk_size = 1 : i64}> : memref<?xf32>, vector<1x1x32xindex>, vector<1x1x32xi1> -> vector<1x1x32xf32>
+ // CHECK-DAG: %[[CST_2:.*]] = arith.constant dense<0.000000e+00> : vector<1x32xf32>
+ // CHECK-DAG: %[[LOCAL_REDUCE:.*]] = vector.multi_reduction <add>, %[[LOAD]], %[[CST_2]] [1] : vector<1x1x32xf32> to vector<1x32xf32>
+ // CHECK-DAG: %[[SHAPE_CAST:.*]] = vector.shape_cast %[[LOCAL_REDUCE]] : vector<1x32xf32> to vector<1x32xf32>
+ // CHECK-DAG: %[[ALLOCA:.*]] = memref.alloca() : memref<4096xi8, 3>
+ // CHECK-DAG: %[[MEM_DESC:.*]] = xegpu.create_mem_desc %[[ALLOCA]] : memref<4096xi8, 3> -> !xegpu.mem_desc<32x32xf32>
+ // CHECK-DAG: %[[SGID:.*]] = gpu.subgroup_id : index
+ // CHECK-DAG: %[[AFFINE1:.*]] = affine.apply #map()[%[[SGID]]]
+ // CHECK-DAG: %[[AFFINE2:.*]] = affine.apply #map1()[%[[SGID]]]
+ // CHECK-DAG: %[[AFFINE3:.*]] = affine.apply #map2()[%[[SGID]]]
+ // CHECK-DAG: %[[MUL1:.*]] = arith.muli {{.*}}, %[[C1:.*]] : index
+ // CHECK-DAG: %[[ADD1:.*]] = arith.addi %[[C0:.*]], %[[MUL1]] : index
+ // CHECK-DAG: %[[MUL2:.*]] = arith.muli %[[AFFINE3]], %[[C1:.*]] : index
+ // CHECK-DAG: %[[ADD2:.*]] = arith.addi %[[ADD1]], %[[MUL2]] : index
+ // CHECK-DAG: %[[COL_OFFSET:.*]] = arith.muli %[[ADD2]], %[[C32:.*]] : index
+ // CHECK-DAG: xegpu.store_matrix %[[SHAPE_CAST]], %[[MEM_DESC]][{{.*}}, %[[COL_OFFSET]]] : vector<1x32xf32>, !xegpu.mem_desc<32x32xf32>, index, index
+ // CHECK-DAG: gpu.barrier
+ // CHECK-DAG: %[[LOAD_SLM:.*]] = xegpu.load_matrix %[[MEM_DESC]][%[[C0:.*]], %[[COL_OFFSET]]] : !xegpu.mem_desc<32x32xf32>, index, index -> vector<32x32xf32>
+ // CHECK-DAG: %[[CST_3:.*]] = arith.constant dense<0.000000e+00> : vector<32xf32>
+ // CHECK-DAG: %[[FINAL_REDUCE:.*]] = vector.multi_reduction <add>, %[[LOAD_SLM]], %[[CST_3]] [0] : vector<32x32xf32> to vector<32xf32>
+ // CHECK-DAG: %[[SHAPE_CAST_FINAL:.*]] = vector.shape_cast %[[CST]] : vector<1x32xf32> to vector<32xf32>
+ // CHECK-DAG: arith.addf %[[FINAL_REDUCE]], %[[SHAPE_CAST_FINAL]] : vector<32xf32>
+ %cst_3 = arith.constant {layout_result_0 = #xegpu.slice<#xegpu.layout<sg_layout = [1, 32, 1], sg_data = [1, 1, 32]>, dims = [1]>} dense<1.0> : vector<1x32xf32>
+ %offset = arith.constant {layout_result_0 = #xegpu.layout<sg_layout = [1, 32, 1], sg_data = [1, 1, 32]>} dense<0> : vector<1x32x32xindex>
+ %mask = arith.constant {layout_result_0 = #xegpu.layout<sg_layout = [1, 32, 1], sg_data = [1, 1, 32]>} dense<true> : vector<1x32x32xi1>
+ %14 = xegpu.load %src[%offset], %mask {chunk_size = 1, layout_result_0 = #xegpu.layout<sg_layout = [1, 32, 1], sg_data = [1, 1, 32]>} : memref<?xf32>, vector<1x32x32xindex>, vector<1x32x32xi1> -> vector<1x32x32xf32>
+ %15 = vector.multi_reduction <add>, %14, %cst_3 {layout_result_0 = #xegpu.slice<#xegpu.layout<sg_layout = [1, 32, 1], sg_data = [1, 1, 32]>, dims = [1]>} [1] : vector<1x32x32xf32> to vector<1x32xf32>
+ // CHECK-DAG: gpu.return
+ gpu.return
+}
+
+ // CHECK-LABEL: gpu.func @vector_reduce_cross_sg_dim_0
+ // CHECK-SAME: (%[[ARG0:.*]]: memref<256x128xf32>)
+ gpu.func @vector_reduce_cross_sg_dim_0(%src: memref<256x128xf32>) {
+ // CHECK-DAG: %[[CST:.*]] = arith.constant dense<0.000000e+00> : vector<32xf32>
+ // CHECK-DAG: %[[SGID:.*]] = gpu.subgroup_id : index
+ // CHECK-DAG: %[[REM4:.*]] = index.remu %[[SGID]], %[[C4:.*]]
+ // CHECK-DAG: %[[DIV4:.*]] = index.divu %[[SGID]], %[[C4:.*]]
+ // CHECK-DAG: %[[REM8:.*]] = index.remu %[[DIV4]], %[[C8:.*]]
+ // CHECK-DAG: %[[MUL1:.*]] = index.mul %[[REM8]], %[[C32:.*]]
+ // CHECK-DAG: %[[MUL2:.*]] = index.mul %[[REM4]], %[[C32:.*]]
+ // CHECK-DAG: %[[REM256:.*]] = index.remu %[[MUL1]], %[[C256:.*]]
+ // CHECK-DAG: %[[REM128:.*]] = index.remu %[[MUL2]], %[[C128:.*]]
+ // CHECK-DAG: %[[TDESC:.*]] = xegpu.create_nd_tdesc %[[ARG0]][%[[REM256]], %[[REM128]]] : memref<256x128xf32> -> !xegpu.tensor_desc<32x32xf32>
+ // CHECK-DAG: %[[LOAD:.*]] = xegpu.load_nd %[[TDESC]] : !xegpu.tensor_desc<32x32xf32> -> vector<32x32xf32>
+ // CHECK-DAG: %[[CST_LOCAL:.*]] = arith.constant dense<0.000000e+00> : vector<32xf32>
+ // CHECK-DAG: %[[LOCAL_REDUCE:.*]] = vector.multi_reduction <add>, %[[LOAD]], %[[CST_LOCAL]] [0] : vector<32x32xf32> to vector<32xf32>
+ // CHECK-DAG: %[[SHAPE_CAST:.*]] = vector.shape_cast %[[LOCAL_REDUCE]] : vector<32xf32> to vector<1x32xf32>
+ // CHECK-DAG: %[[ALLOCA:.*]] = memref.alloca() : memref<4096xi8, 3>
+ // CHECK-DAG: %[[MEM_DESC:.*]] = xegpu.create_mem_desc %[[ALLOCA]] : memref<4096xi8, 3> -> !xegpu.mem_desc<8x128xf32>
+ // CHECK-DAG: %[[SGID2:.*]] = gpu.subgroup_id : index
+ // CHECK-DAG: %[[AFFINE1:.*]] = affine.apply #map3()[%[[SGID2]]]
+ // CHECK-DAG: %[[AFFINE2:.*]] = affine.apply #map4()[%[[SGID2]]]
+ // CHECK-DAG: %[[MUL_AFFINE:.*]] = arith.muli {{.*}}, %[[C1:.*]] : index
+ // CHECK-DAG: %[[ADD_OFFSET:.*]] = arith.addi %[[C0:.*]], %[[MUL_AFFINE]] : index
+ // CHECK-DAG: %[[COL_OFFSET:.*]] = arith.muli %[[ADD_OFFSET]], %[[C32:.*]] : index
+ // CHECK-DAG: xegpu.store_matrix %[[SHAPE_CAST]], %[[MEM_DESC]][{{.*}}, %[[COL_OFFSET]]] : vector<1x32xf32>, !xegpu.mem_desc<8x128xf32>, index, index
+ // CHECK-DAG: gpu.barrier
+ // CHECK-DAG: %[[LOAD_SLM:.*]] = xegpu.load_matrix %[[MEM_DESC]][%[[C0:.*]], %[[COL_OFFSET]]] : !xegpu.mem_desc<8x128xf32>, index, index -> vector<8x32xf32>
+ // CHECK-DAG: %[[CST_CROSS_SG_1:.*]] = arith.constant dense<0.000000e+00> : vector<32xf32>
+ // CHECK-DAG: %[[FINAL_REDUCE:.*]] = vector.multi_reduction <add>, %[[LOAD_SLM]], %[[CST_CROSS_SG_1]] [0] : vector<8x32xf32> to vector<32xf32>
+ // CHECK-DAG: arith.addf %[[FINAL_REDUCE]], %[[CST:.*]] : vector<32xf32>
+ %cst = arith.constant {layout_result_0 = #xegpu.slice<#xegpu.layout<sg_layout = [8, 4], sg_data = [32, 32]>, dims = [0]>} dense<0.0> : vector<128xf32>
+ %tdesc = xegpu.create_nd_tdesc %src[0, 0] : memref<256x128xf32>
+ -> !xegpu.tensor_desc<256x128xf32, #xegpu.layout<sg_layout = [8, 4], sg_data = [32, 32]>>
+ %load = xegpu.load_nd %tdesc
+ : !xegpu.tensor_desc<256x128xf32, #xegpu.layout<sg_layout = [8, 4], sg_data = [32, 32]>>
+ -> vector<256x128xf32>
+ %reduce = vector.multi_reduction <add>, %load, %cst {layout_result_0 = #xegpu.slice<#xegpu.layout<sg_layout = [8, 4], sg_data = [32, 32]>, dims = [0]>} [0]
+ : vector<256x128xf32> to vector<128xf32>
+ // CHECK-DAG: gpu.return
+ gpu.return
+ }
}
>From 8dda8890f042904cc1415eb9fcb13b7cea0a9637 Mon Sep 17 00:00:00 2001
From: nbpatel <nishant.b.patel at intel.com>
Date: Wed, 10 Dec 2025 23:41:33 +0000
Subject: [PATCH 2/6] Support multi-dim reduce
---
.../Transforms/XeGPUWgToSgDistribute.cpp | 215 +++++++++++++-----
.../XeGPU/xegpu-wg-to-sg-unify-ops.mlir | 152 +++++++++----
2 files changed, 264 insertions(+), 103 deletions(-)
diff --git a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUWgToSgDistribute.cpp b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUWgToSgDistribute.cpp
index 365e1ce1732f6..6b463eae2d528 100644
--- a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUWgToSgDistribute.cpp
+++ b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUWgToSgDistribute.cpp
@@ -1152,8 +1152,125 @@ struct WgToSgVectorShapeCastOp
}
};
-// This pattern transforms vector.multi_dim_reduction ops to work at subgroup
-// level.
+/// This function converts multi-dimensional subgroup indices into a single
+/// linear offset. It's used to calculate memory offsets in SLM for
+/// cross-subgroup reduction coordination.
+///
+/// Parameters:
+/// - sgIds: Multi-dimensional subgroup indices (e.g., [sgId_x, sgId_y, sgId_z])
+/// - dims: Which dimensions to include in linearization (e.g., [0, 2] for x and
+/// z dims)
+/// - sgLayout: Subgroup layout sizes for each dimension (e.g., [4, 8, 2] means
+/// 4x8x2 subgroups)
+///
+/// It linearizes with the first dimension listed in 'dims' varying fastest:
+/// offset = sum(sgIds[dim] * stride[dim])
+/// where stride[dim] = product of the sgLayout sizes of the dimensions that
+/// precede 'dim' in 'dims' (stride is 1 for the first listed dimension)
+///
+/// Example:
+/// - sgLayout = [4, 8, 2], dims = [0, 2] (linearize x and z dimensions)
+/// - sgIds = [1, 3, 1] (subgroup at position x=1, y=3, z=1)
+/// - Calculation:
+/// * dim=0: stride=1, term = sgIds[0] * 1 = 1 * 1 = 1
+/// * dim=2: stride=sgLayout[0]=4, term = sgIds[2] * 4 = 1 * 4 = 4
+/// * linearizedOffset = 1 + 4 = 5
+///
+/// This gives us a unique linear index for each combination of subgroup
+/// positions in the specified dimensions, which is used for SLM row/column
+/// addressing.
+static Value linearizeSubgroupIndices(ConversionPatternRewriter &rewriter,
+ Location loc, ArrayRef<Value> sgIds,
+ ArrayRef<int64_t> dims,
+ ArrayRef<int64_t> sgLayout) {
+ Value linearizedOffset = arith::ConstantIndexOp::create(rewriter, loc, 0);
+ int64_t stride = 1;
+
+ for (int64_t dim : dims) {
+ Value dimVal = sgIds[dim];
+ Value strideVal = arith::ConstantIndexOp::create(rewriter, loc, stride);
+ Value term = arith::MulIOp::create(rewriter, loc, dimVal, strideVal);
+ linearizedOffset =
+ arith::AddIOp::create(rewriter, loc, linearizedOffset, term);
+ stride *= sgLayout[dim];
+ }
+
+ return linearizedOffset;
+}
+
+// Helper function to create the appropriate binary operation based on reduction
+// kind
+static Value reductionOpKind(ConversionPatternRewriter &rewriter, Location loc,
+ vector::CombiningKind kind, Value lhs, Value rhs) {
+ Type elemType = getElementTypeOrSelf(lhs.getType());
+ bool isFloat = isa<FloatType>(elemType);
+
+ switch (kind) {
+ case vector::CombiningKind::ADD:
+ return isFloat ? arith::AddFOp::create(rewriter, loc, lhs, rhs).getResult()
+ : arith::AddIOp::create(rewriter, loc, lhs, rhs).getResult();
+ case vector::CombiningKind::MUL:
+ return isFloat ? arith::MulFOp::create(rewriter, loc, lhs, rhs).getResult()
+ : arith::MulIOp::create(rewriter, loc, lhs, rhs).getResult();
+ case vector::CombiningKind::MINSI:
+ return arith::MinSIOp::create(rewriter, loc, lhs, rhs).getResult();
+ case vector::CombiningKind::MINUI:
+ return arith::MinUIOp::create(rewriter, loc, lhs, rhs).getResult();
+ case vector::CombiningKind::MAXSI:
+ return arith::MaxSIOp::create(rewriter, loc, lhs, rhs).getResult();
+ case vector::CombiningKind::MAXUI:
+ return arith::MaxUIOp::create(rewriter, loc, lhs, rhs).getResult();
+ case vector::CombiningKind::AND:
+ return arith::AndIOp::create(rewriter, loc, lhs, rhs).getResult();
+ case vector::CombiningKind::OR:
+ return arith::OrIOp::create(rewriter, loc, lhs, rhs).getResult();
+ case vector::CombiningKind::XOR:
+ return arith::XOrIOp::create(rewriter, loc, lhs, rhs).getResult();
+ case vector::CombiningKind::MINNUMF:
+ return arith::MinNumFOp::create(rewriter, loc, lhs, rhs).getResult();
+ case vector::CombiningKind::MAXNUMF:
+ return arith::MaxNumFOp::create(rewriter, loc, lhs, rhs).getResult();
+ case vector::CombiningKind::MINIMUMF:
+ return arith::MinimumFOp::create(rewriter, loc, lhs, rhs).getResult();
+ case vector::CombiningKind::MAXIMUMF:
+ return arith::MaximumFOp::create(rewriter, loc, lhs, rhs).getResult();
+ default:
+ llvm_unreachable("Unsupported reduction kind");
+ }
+}
+
+/// This pattern transforms vector.multi_reduction operations from
+/// workgroup-level to subgroup-level execution with support for multiple
+/// reduction dimensions.
+///
+/// Steps include:
+/// 1. LOCAL REDUCTION :
+/// - Each subgroup performs local reduction on its data slice
+/// - Uses ZERO accumulator to avoid double-counting during cross-subgroup
+/// phase
+///
+/// 2. CROSS-SUBGROUP :
+/// - Determines if cross-subgroup reduction is needed (when sg_layout > 1 in
+/// reduction dims)
+/// - If not needed, combines the original accumulator and returns the results
+///
+/// 3. SHARED LOCAL MEMORY (SLM) PHASE (when cross-subgroup reduction needed):
+/// a) SLM Layout Design:
+/// - Rows: subgroups participating in reduction (product of sg_layout in
+/// reduction dims)
+/// - Cols: total result elements across non-reduction dimensions
+///
+/// b) Store Phase:
+/// - Each subgroup stores its local reduction result to SLM
+/// - Row offset: linearized index of subgroup in reduction dimensions
+/// - Col offset: linearized index of subgroup in non-reduction dimensions
+///
+/// c) Load and Final Reduction Phase:
+/// - Each subgroup loads a column of data (all reduction participants for
+/// its position)
+/// - Performs final reduction along the loaded dimension
+/// - Combines the original accumulator to get the final result
+///
struct WgToSgMultiDimReductionOp
: public OpConversionPattern<vector::MultiDimReductionOp> {
using OpConversionPattern<vector::MultiDimReductionOp>::OpConversionPattern;
@@ -1176,9 +1293,6 @@ struct WgToSgMultiDimReductionOp
return failure();
auto reductionDims = llvm::to_vector(op.getReductionDims());
- if (reductionDims.size() != 1)
- return rewriter.notifyMatchFailure(
- op, "Only single dimension reduction is supported");
// Get sg_layout and sg_data from the parent layout
SmallVector<int64_t> sgLayout;
@@ -1194,20 +1308,10 @@ struct WgToSgMultiDimReductionOp
// Step 1: perform local subgroup reductions with ZERO accumulator
SmallVector<Value> localReductions;
- auto sources = adaptor.getSource();
- auto accs = adaptor.getAcc();
-
- SmallVector<Value> expandedAccs;
- if (accs.size() == 1 && sources.size() > 1) {
- for (size_t i = 0; i < sources.size(); ++i)
- expandedAccs.push_back(accs[0]);
- } else
- expandedAccs = llvm::to_vector(accs);
-
SmallVector<int64_t> sgShape =
getSgShapeAndCount(originalSrcShape, layout).first;
- VectorType newDstType = VectorType::get({sgShape}, elemTy);
- for (auto [sgSrc, sgAcc] : llvm::zip(sources, expandedAccs)) {
+ VectorType newDstType = VectorType::get(sgShape, elemTy);
+ for (auto sgSrc : adaptor.getSource()) {
// Create ZERO accumulator for local reduction
auto zeroLocalAcc = arith::ConstantOp::create(
rewriter, loc, newDstType,
@@ -1219,26 +1323,32 @@ struct WgToSgMultiDimReductionOp
localReductions.push_back(localReduce.getResult());
}
- // Check if cross-subgroup reduction is needed
- int64_t reductionDim = reductionDims[0];
- bool needsCrossSubgroupReduction = (sgLayout[reductionDim] > 1);
+ // Check if cross-subgroup reduction is needed for any reduction dimension
+ bool needsCrossSubgroupReduction = false;
+ SmallVector<int64_t> crossSgReductionDims;
+ for (int64_t reductionDim : reductionDims) {
+ if (sgLayout[reductionDim] > 1) {
+ needsCrossSubgroupReduction = true;
+ crossSgReductionDims.push_back(reductionDim);
+ }
+ }
// If no cross-subgroup reduction needed, add accumulator and return
if (!needsCrossSubgroupReduction) {
SmallVector<Value> results;
for (auto localResult : localReductions) {
- auto finalResult = arith::AddFOp::create(rewriter, loc, localResult,
- adaptor.getAcc()[0]);
- if (auto defOp = finalResult.getResult().getDefiningOp())
+ auto finalResult = reductionOpKind(rewriter, loc, op.getKind(),
+ localResult, adaptor.getAcc()[0]);
+ if (auto defOp = finalResult.getDefiningOp())
xegpu::setDistributeLayoutAttr(defOp->getResult(0),
layout.dropSgLayoutAndData());
- results.push_back(finalResult.getResult());
+ results.push_back(finalResult);
}
rewriter.replaceOpWithMultiple(op, {results});
return success();
}
- // Step 2: Cross-subgroup reduction using SLM
+ // Step 2: cross-subgroup reduction using SLM
// Calculate total elements in local result
int64_t localElements = computeProduct(sgShape);
@@ -1250,16 +1360,16 @@ struct WgToSgMultiDimReductionOp
rewriter, loc, storeType2D, localReductions[0]);
Value storeData = storeShapeCast.getResult();
- // Calculate SLM shape
- int64_t totalReductionSubgroups =
- sgLayout[static_cast<size_t>(reductionDims[0])];
+ // Calculate SLM shape - rows for sg's in reduction dims, cols for total
+ // result elements across all subgroups in non-reduction dimensions
+ int64_t totalReductionSubgroups = 1;
+ for (int64_t dim : crossSgReductionDims) {
+ totalReductionSubgroups *= sgLayout[dim];
+ }
// Total result elements across all subgroups in non-reduction dimensions
- int64_t totalResultElements = localElements;
- for (size_t i = 0; i < sgLayout.size(); ++i) {
- if (!llvm::is_contained(reductionDims, static_cast<int64_t>(i)))
- totalResultElements *= sgLayout[i];
- }
+ int64_t totalResultElements =
+ localElements * computeProduct(sgLayout) / totalReductionSubgroups;
SmallVector<int64_t> slmShape2D = {totalReductionSubgroups,
totalResultElements};
@@ -1293,29 +1403,27 @@ struct WgToSgMultiDimReductionOp
return failure();
SmallVector<Value> sgIds = *sgIdsResult;
- // Row offset is simply the subgroup ID along the reduction dimension
- Value rowOffset = sgIds[reductionDim];
+ // Row offset: linearize reduction dimension indices
+ Value rowOffsetStore = linearizeSubgroupIndices(
+ rewriter, loc, sgIds, crossSgReductionDims, sgLayout);
- // Column offset: linearize all non-reduction dimensions and multiply by
- // localElements
- Value colOffset = arith::ConstantIndexOp::create(rewriter, loc, 0);
- int64_t currentStride = 1;
+ // Column offset: linearize non-reduction dimension indices
+ SmallVector<int64_t> nonReductionDims;
for (size_t i = 0; i < sgLayout.size(); ++i) {
- if (static_cast<int64_t>(i) != reductionDim) { // Skip reduction dimension
- Value dimVal = sgIds[i];
- Value strideVal =
- arith::ConstantIndexOp::create(rewriter, loc, currentStride);
- Value term = arith::MulIOp::create(rewriter, loc, dimVal, strideVal);
- colOffset = arith::AddIOp::create(rewriter, loc, colOffset, term);
- currentStride *= sgLayout[i];
+ if (!llvm::is_contained(reductionDims, static_cast<int64_t>(i))) {
+ nonReductionDims.push_back(static_cast<int64_t>(i));
}
}
+
+ Value colOffset = linearizeSubgroupIndices(rewriter, loc, sgIds,
+ nonReductionDims, sgLayout);
+
Value localElementsVal =
arith::ConstantIndexOp::create(rewriter, loc, localElements);
colOffset =
arith::MulIOp::create(rewriter, loc, colOffset, localElementsVal);
- SmallVector<OpFoldResult> storeOffsets2D = {rowOffset, colOffset};
+ SmallVector<OpFoldResult> storeOffsets2D = {rowOffsetStore, colOffset};
xegpu::StoreMatrixOp::create(rewriter, loc, storeData, memDesc.getResult(),
storeOffsets2D, /*layout=*/nullptr);
@@ -1328,10 +1436,9 @@ struct WgToSgMultiDimReductionOp
// Load offsets - each subgroup loads its column based on non-reduction
// position
- Value loadOffsetY = arith::ConstantIndexOp::create(rewriter, loc, 0);
- Value loadOffsetX = colOffset;
+ Value rowOffsetLoad = arith::ConstantIndexOp::create(rewriter, loc, 0);
- SmallVector<OpFoldResult> loadOffsets2D = {loadOffsetY, loadOffsetX};
+ SmallVector<OpFoldResult> loadOffsets2D = {rowOffsetLoad, colOffset};
auto loadOp = xegpu::LoadMatrixOp::create(
rewriter, loc, loadType2D, memDesc.getResult(), loadOffsets2D,
@@ -1370,14 +1477,14 @@ struct WgToSgMultiDimReductionOp
}
}
- auto finalResult =
- arith::AddFOp::create(rewriter, loc, finalReduce.getResult(), accToAdd);
+ auto finalResult = reductionOpKind(rewriter, loc, op.getKind(),
+ finalReduce.getResult(), accToAdd);
- if (auto defOp = finalResult.getResult().getDefiningOp())
+ if (auto defOp = finalResult.getDefiningOp())
xegpu::setDistributeLayoutAttr(defOp->getResult(0),
layout.dropSgLayoutAndData());
- rewriter.replaceOp(op, finalResult.getResult());
+ rewriter.replaceOp(op, finalResult);
return success();
}
};
diff --git a/mlir/test/Dialect/XeGPU/xegpu-wg-to-sg-unify-ops.mlir b/mlir/test/Dialect/XeGPU/xegpu-wg-to-sg-unify-ops.mlir
index 13e7be6b2fa27..342fef942be51 100644
--- a/mlir/test/Dialect/XeGPU/xegpu-wg-to-sg-unify-ops.mlir
+++ b/mlir/test/Dialect/XeGPU/xegpu-wg-to-sg-unify-ops.mlir
@@ -5,6 +5,9 @@
// CHECK-DAG: #map2 = affine_map<()[s0] -> (0)>
// CHECK-DAG: #map3 = affine_map<()[s0] -> (s0 floordiv 4)>
// CHECK-DAG: #map4 = affine_map<()[s0] -> (s0 mod 4)>
+// CHECK-DAG: #map5 = affine_map<()[s0] -> ((s0 mod 32) floordiv 16)>
+// CHECK-DAG: #map6 = affine_map<()[s0] -> (s0 mod 16)>
+// CHECK-DAG: #map7 = affine_map<()[s0] -> ((s0 mod 16) floordiv 4)>
gpu.module @test_distribution {
// CHECK-LABEL: create_nd_tdesc_no_offset
// CHECK-SAME: %[[ARG_0:.*]]: memref<256x128xf32>
@@ -640,68 +643,72 @@ gpu.module @test_distribution {
}
// CHECK-LABEL: gpu.func @vector_reduce_cross_sg_dim_1
-// CHECK-SAME: (%[[ARG0:.*]]: memref<?xf32>)
-gpu.func @vector_reduce_cross_sg_dim_1(%src: memref<?xf32>) {
- // CHECK-DAG: %[[CST:.*]] = arith.constant dense<1.000000e+00> : vector<1x32xf32>
- // CHECK-DAG: %[[CST_0:.*]] = arith.constant dense<0> : vector<1x1x32xindex>
- // CHECK-DAG: %[[CST_1:.*]] = arith.constant dense<true> : vector<1x1x32xi1>
- // CHECK-DAG: %[[LOAD:.*]] = xegpu.load %[[ARG0]][%[[CST_0]]], %[[CST_1]] <{chunk_size = 1 : i64}> : memref<?xf32>, vector<1x1x32xindex>, vector<1x1x32xi1> -> vector<1x1x32xf32>
- // CHECK-DAG: %[[CST_2:.*]] = arith.constant dense<0.000000e+00> : vector<1x32xf32>
- // CHECK-DAG: %[[LOCAL_REDUCE:.*]] = vector.multi_reduction <add>, %[[LOAD]], %[[CST_2]] [1] : vector<1x1x32xf32> to vector<1x32xf32>
- // CHECK-DAG: %[[SHAPE_CAST:.*]] = vector.shape_cast %[[LOCAL_REDUCE]] : vector<1x32xf32> to vector<1x32xf32>
- // CHECK-DAG: %[[ALLOCA:.*]] = memref.alloca() : memref<4096xi8, 3>
- // CHECK-DAG: %[[MEM_DESC:.*]] = xegpu.create_mem_desc %[[ALLOCA]] : memref<4096xi8, 3> -> !xegpu.mem_desc<32x32xf32>
- // CHECK-DAG: %[[SGID:.*]] = gpu.subgroup_id : index
- // CHECK-DAG: %[[AFFINE1:.*]] = affine.apply #map()[%[[SGID]]]
- // CHECK-DAG: %[[AFFINE2:.*]] = affine.apply #map1()[%[[SGID]]]
- // CHECK-DAG: %[[AFFINE3:.*]] = affine.apply #map2()[%[[SGID]]]
- // CHECK-DAG: %[[MUL1:.*]] = arith.muli {{.*}}, %[[C1:.*]] : index
- // CHECK-DAG: %[[ADD1:.*]] = arith.addi %[[C0:.*]], %[[MUL1]] : index
- // CHECK-DAG: %[[MUL2:.*]] = arith.muli %[[AFFINE3]], %[[C1:.*]] : index
- // CHECK-DAG: %[[ADD2:.*]] = arith.addi %[[ADD1]], %[[MUL2]] : index
- // CHECK-DAG: %[[COL_OFFSET:.*]] = arith.muli %[[ADD2]], %[[C32:.*]] : index
- // CHECK-DAG: xegpu.store_matrix %[[SHAPE_CAST]], %[[MEM_DESC]][{{.*}}, %[[COL_OFFSET]]] : vector<1x32xf32>, !xegpu.mem_desc<32x32xf32>, index, index
- // CHECK-DAG: gpu.barrier
- // CHECK-DAG: %[[LOAD_SLM:.*]] = xegpu.load_matrix %[[MEM_DESC]][%[[C0:.*]], %[[COL_OFFSET]]] : !xegpu.mem_desc<32x32xf32>, index, index -> vector<32x32xf32>
- // CHECK-DAG: %[[CST_3:.*]] = arith.constant dense<0.000000e+00> : vector<32xf32>
- // CHECK-DAG: %[[FINAL_REDUCE:.*]] = vector.multi_reduction <add>, %[[LOAD_SLM]], %[[CST_3]] [0] : vector<32x32xf32> to vector<32xf32>
- // CHECK-DAG: %[[SHAPE_CAST_FINAL:.*]] = vector.shape_cast %[[CST]] : vector<1x32xf32> to vector<32xf32>
- // CHECK-DAG: arith.addf %[[FINAL_REDUCE]], %[[SHAPE_CAST_FINAL]] : vector<32xf32>
- %cst_3 = arith.constant {layout_result_0 = #xegpu.slice<#xegpu.layout<sg_layout = [1, 32, 1], sg_data = [1, 1, 32]>, dims = [1]>} dense<1.0> : vector<1x32xf32>
- %offset = arith.constant {layout_result_0 = #xegpu.layout<sg_layout = [1, 32, 1], sg_data = [1, 1, 32]>} dense<0> : vector<1x32x32xindex>
- %mask = arith.constant {layout_result_0 = #xegpu.layout<sg_layout = [1, 32, 1], sg_data = [1, 1, 32]>} dense<true> : vector<1x32x32xi1>
- %14 = xegpu.load %src[%offset], %mask {chunk_size = 1, layout_result_0 = #xegpu.layout<sg_layout = [1, 32, 1], sg_data = [1, 1, 32]>} : memref<?xf32>, vector<1x32x32xindex>, vector<1x32x32xi1> -> vector<1x32x32xf32>
- %15 = vector.multi_reduction <add>, %14, %cst_3 {layout_result_0 = #xegpu.slice<#xegpu.layout<sg_layout = [1, 32, 1], sg_data = [1, 1, 32]>, dims = [1]>} [1] : vector<1x32x32xf32> to vector<1x32xf32>
- // CHECK-DAG: gpu.return
- gpu.return
-}
+ // CHECK-SAME: (%[[ARG0:.*]]: memref<?xf32>)
+ gpu.func @vector_reduce_cross_sg_dim_1(%src: memref<?xf32>) {
+ // CHECK-DAG: %[[CST:.*]] = arith.constant dense<1.000000e+00> : vector<1x32xf32>
+ // CHECK-DAG: %[[CST_0:.*]] = arith.constant dense<0> : vector<1x1x32xindex>
+ // CHECK-DAG: %[[CST_1:.*]] = arith.constant dense<true> : vector<1x1x32xi1>
+ // CHECK-DAG: %[[LOAD:.*]] = xegpu.load %{{.*}}[%[[CST_0]]], %[[CST_1]] <{chunk_size = 1 : i64}> : memref<?xf32>, vector<1x1x32xindex>, vector<1x1x32xi1> -> vector<1x1x32xf32>
+ // CHECK-DAG: %[[CST_2:.*]] = arith.constant dense<0.000000e+00> : vector<1x32xf32>
+ // CHECK-DAG: %[[LOCAL_REDUCE:.*]] = vector.multi_reduction <add>, %[[LOAD]], %[[CST_2]] [1] : vector<1x1x32xf32> to vector<1x32xf32>
+ // CHECK-DAG: %[[SHAPE_CAST:.*]] = vector.shape_cast %[[LOCAL_REDUCE]] : vector<1x32xf32> to vector<1x32xf32>
+ // CHECK-DAG: %[[ALLOCA:.*]] = memref.alloca() : memref<4096xi8, 3>
+ // CHECK-DAG: %[[MEM_DESC:.*]] = xegpu.create_mem_desc %[[ALLOCA]] : memref<4096xi8, 3> -> !xegpu.mem_desc<32x32xf32>
+ // CHECK-DAG: %[[SGID:.*]] = gpu.subgroup_id : index
+ // CHECK-DAG: %[[AFFINE1:.*]] = affine.apply #map()[%[[SGID]]]
+ // CHECK-DAG: %[[AFFINE2:.*]] = affine.apply #map1()[%[[SGID]]]
+ // CHECK-DAG: %[[AFFINE3:.*]] = affine.apply #map2()[%[[SGID]]]
+ // CHECK-DAG: %[[MUL1:.*]] = arith.muli {{.*}}, %[[C1:.*]] : index
+ // CHECK-DAG: %[[ROW_OFFSET:.*]] = arith.addi %[[C0:.*]], %[[MUL1]] : index
+ // CHECK-DAG: %[[MUL2:.*]] = arith.muli %[[AFFINE1]], %[[C1:.*]] : index
+ // CHECK-DAG: %[[ADD1:.*]] = arith.addi %[[C0:.*]], %[[MUL2]] : index
+ // CHECK-DAG: %[[MUL3:.*]] = arith.muli %[[AFFINE3]], %[[C1:.*]] : index
+ // CHECK-DAG: %[[ADD2:.*]] = arith.addi %[[ADD1]], %[[MUL3]] : index
+ // CHECK-DAG: %[[COL_OFFSET:.*]] = arith.muli %[[ADD2]], %[[C32:.*]] : index
+ // CHECK-DAG: xegpu.store_matrix %[[SHAPE_CAST]], %[[MEM_DESC]][%[[ROW_OFFSET]], %[[COL_OFFSET]]] : vector<1x32xf32>, !xegpu.mem_desc<32x32xf32>, index, index
+ // CHECK-DAG: gpu.barrier
+ // CHECK-DAG: %[[LOAD_SLM:.*]] = xegpu.load_matrix %[[MEM_DESC]][%[[C0:.*]], %[[COL_OFFSET]]] : !xegpu.mem_desc<32x32xf32>, index, index -> vector<32x32xf32>
+ // CHECK-DAG: %[[CST_3:.*]] = arith.constant dense<0.000000e+00> : vector<32xf32>
+ // CHECK-DAG: %[[FINAL_REDUCE:.*]] = vector.multi_reduction <add>, %[[LOAD_SLM]], %[[CST_3]] [0] : vector<32x32xf32> to vector<32xf32>
+ // CHECK-DAG: %[[SHAPE_CAST_FINAL:.*]] = vector.shape_cast %[[CST]] : vector<1x32xf32> to vector<32xf32>
+ // CHECK-DAG: %{{.*}} = arith.addf %[[FINAL_REDUCE]], %[[SHAPE_CAST_FINAL]] : vector<32xf32>
+ // CHECK-DAG: gpu.return
+ %cst_3 = arith.constant {layout_result_0 = #xegpu.slice<#xegpu.layout<sg_layout = [1, 32, 1], sg_data = [1, 1, 32]>, dims = [1]>} dense<1.0> : vector<1x32xf32>
+ %offset = arith.constant {layout_result_0 = #xegpu.layout<sg_layout = [1, 32, 1], sg_data = [1, 1, 32]>} dense<0> : vector<1x32x32xindex>
+ %mask = arith.constant {layout_result_0 = #xegpu.layout<sg_layout = [1, 32, 1], sg_data = [1, 1, 32]>} dense<true> : vector<1x32x32xi1>
+ %14 = xegpu.load %src[%offset], %mask {chunk_size = 1, layout_result_0 = #xegpu.layout<sg_layout = [1, 32, 1], sg_data = [1, 1, 32]>} : memref<?xf32>, vector<1x32x32xindex>, vector<1x32x32xi1> -> vector<1x32x32xf32>
+ %15 = vector.multi_reduction <add>, %14, %cst_3 {layout_result_0 = #xegpu.slice<#xegpu.layout<sg_layout = [1, 32, 1], sg_data = [1, 1, 32]>, dims = [1]>} [1] : vector<1x32x32xf32> to vector<1x32xf32>
+ gpu.return
+ }
// CHECK-LABEL: gpu.func @vector_reduce_cross_sg_dim_0
// CHECK-SAME: (%[[ARG0:.*]]: memref<256x128xf32>)
gpu.func @vector_reduce_cross_sg_dim_0(%src: memref<256x128xf32>) {
// CHECK-DAG: %[[CST:.*]] = arith.constant dense<0.000000e+00> : vector<32xf32>
// CHECK-DAG: %[[SGID:.*]] = gpu.subgroup_id : index
- // CHECK-DAG: %[[REM4:.*]] = arith.remui %[[SGID]], %[[C4:.*]]
- // CHECK-DAG: %[[DIV4:.*]] = arith.divui %[[SGID]], %[[C4:.*]]
- // CHECK-DAG: %[[REM8:.*]] = arith.remui %[[DIV4]], %[[C8:.*]]
- // CHECK-DAG: %[[MUL1:.*]] = arith.muli %[[REM8]], %[[C32:.*]]
- // CHECK-DAG: %[[MUL2:.*]] = arith.muli %[[REM4]], %[[C32:.*]]
- // CHECK-DAG: %[[REM256:.*]] = arith.remui %[[MUL1]], %[[C256:.*]]
- // CHECK-DAG: %[[REM128:.*]] = arith.remui %[[MUL2]], %[[C128:.*]]
- // CHECK-DAG: %[[TDESC:.*]] = xegpu.create_nd_tdesc %[[ARG0]][%[[REM256]], %[[REM128]]] : memref<256x128xf32> -> !xegpu.tensor_desc<32x32xf32>
- // CHECK-DAG: %[[LOAD:.*]] = xegpu.load_nd %[[TDESC]] : !xegpu.tensor_desc<32x32xf32> -> vector<32x32xf32>
+ // CHECK-DAG: %[[REM1:.*]] = arith.remui %[[SGID]], %[[C4:.*]] : index
+ // CHECK-DAG: %[[DIV1:.*]] = arith.divui %[[SGID]], %[[C4:.*]] : index
+ // CHECK-DAG: %[[REM2:.*]] = arith.remui %[[DIV1]], %[[C8:.*]] : index
+ // CHECK-DAG: %[[MUL1:.*]] = arith.muli %[[REM2]], %[[C32:.*]] : index
+ // CHECK-DAG: %[[MUL2:.*]] = arith.muli %[[REM1]], %[[C32:.*]] : index
+ // CHECK-DAG: %[[REM3:.*]] = arith.remui %[[MUL1]], %[[C256:.*]] : index
+ // CHECK-DAG: %[[REM4:.*]] = arith.remui %[[MUL2]], %[[C128:.*]] : index
+ // CHECK-DAG: %[[TDESC:.*]] = xegpu.create_nd_tdesc %[[ARG0]][%[[REM3]], %[[REM4]]] : memref<256x128xf32> -> !xegpu.tensor_desc<32x32xf32>
+ // CHECK-DAG: %[[LOAD_ND:.*]] = xegpu.load_nd %[[TDESC]] : !xegpu.tensor_desc<32x32xf32> -> vector<32x32xf32>
// CHECK-DAG: %[[CST_LOCAL:.*]] = arith.constant dense<0.000000e+00> : vector<32xf32>
- // CHECK-DAG: %[[LOCAL_REDUCE:.*]] = vector.multi_reduction <add>, %[[LOAD]], %[[CST_LOCAL]] [0] : vector<32x32xf32> to vector<32xf32>
+ // CHECK-DAG: %[[LOCAL_REDUCE:.*]] = vector.multi_reduction <add>, %[[LOAD_ND]], %[[CST_LOCAL]] [0] : vector<32x32xf32> to vector<32xf32>
// CHECK-DAG: %[[SHAPE_CAST:.*]] = vector.shape_cast %[[LOCAL_REDUCE]] : vector<32xf32> to vector<1x32xf32>
// CHECK-DAG: %[[ALLOCA:.*]] = memref.alloca() : memref<4096xi8, 3>
// CHECK-DAG: %[[MEM_DESC:.*]] = xegpu.create_mem_desc %[[ALLOCA]] : memref<4096xi8, 3> -> !xegpu.mem_desc<8x128xf32>
// CHECK-DAG: %[[SGID2:.*]] = gpu.subgroup_id : index
// CHECK-DAG: %[[AFFINE1:.*]] = affine.apply #map3()[%[[SGID2]]]
// CHECK-DAG: %[[AFFINE2:.*]] = affine.apply #map4()[%[[SGID2]]]
- // CHECK-DAG: %[[MUL_AFFINE:.*]] = arith.muli {{.*}}, %[[C1:.*]] : index
- // CHECK-DAG: %[[ADD_OFFSET:.*]] = arith.addi %[[C0:.*]], %[[MUL_AFFINE]] : index
- // CHECK-DAG: %[[COL_OFFSET:.*]] = arith.muli %[[ADD_OFFSET]], %[[C32:.*]] : index
- // CHECK-DAG: xegpu.store_matrix %[[SHAPE_CAST]], %[[MEM_DESC]][{{.*}}, %[[COL_OFFSET]]] : vector<1x32xf32>, !xegpu.mem_desc<8x128xf32>, index, index
+ // CHECK-DAG: %[[MUL3:.*]] = arith.muli {{.*}}, %[[C1:.*]] : index
+ // CHECK-DAG: %[[ROW_OFFSET:.*]] = arith.addi %[[C0:.*]], %[[MUL3]] : index
+ // CHECK-DAG: %[[MUL4:.*]] = arith.muli {{.*}}, %[[C1:.*]] : index
+ // CHECK-DAG: %[[ADD1:.*]] = arith.addi %[[C0:.*]], %[[MUL4]] : index
+ // CHECK-DAG: %[[COL_OFFSET:.*]] = arith.muli %[[ADD1]], %[[C32:.*]] : index
+ // CHECK-DAG: xegpu.store_matrix %[[SHAPE_CAST]], %[[MEM_DESC]][%[[ROW_OFFSET]], %[[COL_OFFSET]]] : vector<1x32xf32>, !xegpu.mem_desc<8x128xf32>, index, index
// CHECK-DAG: gpu.barrier
// CHECK-DAG: %[[LOAD_SLM:.*]] = xegpu.load_matrix %[[MEM_DESC]][%[[C0:.*]], %[[COL_OFFSET]]] : !xegpu.mem_desc<8x128xf32>, index, index -> vector<8x32xf32>
// CHECK-DAG: %[[CST_CROSS_SG_1:.*]] = arith.constant dense<0.000000e+00> : vector<32xf32>
@@ -715,7 +722,54 @@ gpu.func @vector_reduce_cross_sg_dim_1(%src: memref<?xf32>) {
-> vector<256x128xf32>
%reduce = vector.multi_reduction <add>, %load, %cst {layout_result_0 = #xegpu.slice<#xegpu.layout<sg_layout = [8, 4], sg_data = [32, 32]>, dims = [0]>} [0]
: vector<256x128xf32> to vector<128xf32>
+ gpu.return
+ }
+
+ // CHECK-LABEL: gpu.func @vector_reduce_multi_dim
+ // CHECK-SAME: (%[[ARG0:.*]]: memref<?xf32>)
+ gpu.func @vector_reduce_multi_dim(%src: memref<?xf32>) {
+ // CHECK-DAG: %[[CST:.*]] = arith.constant dense<0.000000e+00> : vector<1x1xf32>
+ // CHECK-DAG: %[[CST_0:.*]] = arith.constant dense<0> : vector<1x1x32x32xindex>
+ // CHECK-DAG: %[[CST_1:.*]] = arith.constant dense<true> : vector<1x1x32x32xi1>
+ // CHECK-DAG: %[[LOAD:.*]] = xegpu.load %{{.*}}[%[[CST_0]]], %[[CST_1]] <{chunk_size = 1 : i64}> : memref<?xf32>, vector<1x1x32x32xindex>, vector<1x1x32x32xi1> -> vector<1x1x32x32xf32>
+ // CHECK-DAG: %[[CST_2:.*]] = arith.constant dense<0.000000e+00> : vector<1x1xf32>
+ // CHECK-DAG: %[[LOCAL_REDUCE:.*]] = vector.multi_reduction <add>, %[[LOAD]], %[[CST_2]] [2, 3] : vector<1x1x32x32xf32> to vector<1x1xf32>
+ // CHECK-DAG: %[[SHAPE_CAST:.*]] = vector.shape_cast %[[LOCAL_REDUCE]] : vector<1x1xf32> to vector<1x1xf32>
+ // CHECK-DAG: %[[ALLOCA:.*]] = memref.alloca() : memref<256xi8, 3>
+ // CHECK-DAG: %[[MEM_DESC:.*]] = xegpu.create_mem_desc %[[ALLOCA]] : memref<256xi8, 3> -> !xegpu.mem_desc<16x4xf32>
+ // CHECK-DAG: %[[SGID:.*]] = gpu.subgroup_id : index
+ // CHECK-DAG: %[[C2:.*]] = arith.constant 2 : index
+ // CHECK-DAG: %[[C4:.*]] = arith.constant 4 : index
+ // CHECK-DAG: %[[C16:.*]] = arith.constant 16 : index
+ // CHECK-DAG: %[[C32:.*]] = arith.constant 32 : index
+ // CHECK-DAG: %[[AFFINE1:.*]] = affine.apply #map()[%[[SGID]]]
+ // CHECK-DAG: %[[AFFINE2:.*]] = affine.apply #map1()[%[[SGID]]]
+ // CHECK-DAG: %[[AFFINE3:.*]] = affine.apply #map5()[%[[SGID]]]
+ // CHECK-DAG: %[[AFFINE4:.*]] = affine.apply #map6()[%[[SGID]]]
+ // CHECK-DAG: %[[AFFINE5:.*]] = affine.apply #map7()[%[[SGID]]]
+ // CHECK-DAG: %[[AFFINE6:.*]] = affine.apply #map4()[%[[SGID]]]
+ // CHECK-DAG: %[[MUL1:.*]] = arith.muli {{.*}}, %[[C1:.*]] : index
+ // CHECK-DAG: %[[ADD1:.*]] = arith.addi %[[C0:.*]], %[[MUL1]] : index
+ // CHECK-DAG: %[[MUL2:.*]] = arith.muli {{.*}}, %[[C4:.*]] : index
+ // CHECK-DAG: %[[ROW_OFFSET:.*]] = arith.addi %[[ADD1]], %[[MUL2]] : index
+ // CHECK-DAG: %[[MUL3:.*]] = arith.muli {{.*}}, %[[C1:.*]] : index
+ // CHECK-DAG: %[[ADD2:.*]] = arith.addi %[[C0:.*]], %[[MUL3]] : index
+ // CHECK-DAG: %[[MUL4:.*]] = arith.muli {{.*}}, %[[C2:.*]] : index
+ // CHECK-DAG: %[[ADD3:.*]] = arith.addi %[[ADD2]], %[[MUL4]] : index
+ // CHECK-DAG: %[[COL_OFFSET:.*]] = arith.muli %[[ADD3]], %[[C1:.*]] : index
+ // CHECK-DAG: xegpu.store_matrix %[[SHAPE_CAST]], %[[MEM_DESC]][%[[ROW_OFFSET]], %[[COL_OFFSET]]] : vector<1x1xf32>, !xegpu.mem_desc<16x4xf32>, index, index
+ // CHECK-DAG: gpu.barrier
+ // CHECK-DAG: %[[LOAD_SLM:.*]] = xegpu.load_matrix %[[MEM_DESC]][%[[C0:.*]], %[[COL_OFFSET]]] : !xegpu.mem_desc<16x4xf32>, index, index -> vector<16x1xf32>
+ // CHECK-DAG: %[[CST_3:.*]] = arith.constant dense<0.000000e+00> : vector<1xf32>
+ // CHECK-DAG: %[[FINAL_REDUCE:.*]] = vector.multi_reduction <add>, %[[LOAD_SLM]], %[[CST_3]] [0] : vector<16x1xf32> to vector<1xf32>
+ // CHECK-DAG: %[[SHAPE_CAST_FINAL:.*]] = vector.shape_cast %[[CST]] : vector<1x1xf32> to vector<1xf32>
+ // CHECK-DAG: %[[FINAL_ADD:.*]] = arith.addf %[[FINAL_REDUCE]], %[[SHAPE_CAST_FINAL]] : vector<1xf32>
// CHECK-DAG: gpu.return
+ %cst = arith.constant {layout_result_0 = #xegpu.slice<#xegpu.layout<sg_layout = [2, 2, 4, 4], sg_data = [1, 1, 32, 32]>, dims = [2, 3]>} dense<0.0> : vector<2x2xf32>
+ %offset = arith.constant {layout_result_0 = #xegpu.layout<sg_layout = [2, 2, 4, 4], sg_data = [1, 1, 32, 32]>} dense<0> : vector<2x2x128x128xindex>
+ %mask = arith.constant {layout_result_0 = #xegpu.layout<sg_layout = [2, 2, 4, 4], sg_data = [1, 1, 32, 32]>} dense<true> : vector<2x2x128x128xi1>
+ %load = xegpu.load %src[%offset], %mask {chunk_size = 1, layout_result_0 = #xegpu.layout<sg_layout = [2, 2, 4, 4], sg_data = [1, 1, 32, 32]>} : memref<?xf32>, vector<2x2x128x128xindex>, vector<2x2x128x128xi1> -> vector<2x2x128x128xf32>
+ %reduce = vector.multi_reduction <add>, %load, %cst {layout_result_0 = #xegpu.slice<#xegpu.layout<sg_layout = [2, 2, 4, 4], sg_data = [1, 1, 32, 32]>, dims = [2, 3]>} [2, 3] : vector<2x2x128x128xf32> to vector<2x2xf32>
gpu.return
}
}
>From 2533d43456d779eebefd4f3fd0cdcc3d60af293d Mon Sep 17 00:00:00 2001
From: nbpatel <nishant.b.patel at intel.com>
Date: Fri, 12 Dec 2025 21:42:58 +0000
Subject: [PATCH 3/6] Fix CI
---
mlir/lib/Dialect/XeGPU/Transforms/XeGPUWgToSgDistribute.cpp | 3 +--
1 file changed, 1 insertion(+), 2 deletions(-)
diff --git a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUWgToSgDistribute.cpp b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUWgToSgDistribute.cpp
index ff234469125bc..54f912bef5637 100644
--- a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUWgToSgDistribute.cpp
+++ b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUWgToSgDistribute.cpp
@@ -1243,9 +1243,8 @@ static Value reductionOpKind(ConversionPatternRewriter &rewriter, Location loc,
return arith::MinimumFOp::create(rewriter, loc, lhs, rhs).getResult();
case vector::CombiningKind::MAXIMUMF:
return arith::MaximumFOp::create(rewriter, loc, lhs, rhs).getResult();
- default:
- llvm_unreachable("Unsupported reduction kind");
}
+ llvm_unreachable("unsupported OpKind");
}
/// This pattern transforms vector.multi_dim_reduction operations from
>From be2036ae096e78e6afa342e406a7fc7f4abf639e Mon Sep 17 00:00:00 2001
From: nbpatel <nishant.b.patel at intel.com>
Date: Mon, 12 Jan 2026 16:04:48 +0000
Subject: [PATCH 4/6] Address Feedback
---
.../Transforms/XeGPUWgToSgDistribute.cpp | 56 ++++---------------
1 file changed, 10 insertions(+), 46 deletions(-)
diff --git a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUWgToSgDistribute.cpp b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUWgToSgDistribute.cpp
index 52b84dd0812ce..a1b4b5e502f3a 100644
--- a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUWgToSgDistribute.cpp
+++ b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUWgToSgDistribute.cpp
@@ -1230,46 +1230,6 @@ static Value linearizeSubgroupIndices(ConversionPatternRewriter &rewriter,
return linearizedOffset;
}
-// Helper function to create the appropriate binary operation based on reduction
-// kind
-static Value reductionOpKind(ConversionPatternRewriter &rewriter, Location loc,
- vector::CombiningKind kind, Value lhs, Value rhs) {
- Type elemType = getElementTypeOrSelf(lhs.getType());
- bool isFloat = isa<FloatType>(elemType);
-
- switch (kind) {
- case vector::CombiningKind::ADD:
- return isFloat ? arith::AddFOp::create(rewriter, loc, lhs, rhs).getResult()
- : arith::AddIOp::create(rewriter, loc, lhs, rhs).getResult();
- case vector::CombiningKind::MUL:
- return isFloat ? arith::MulFOp::create(rewriter, loc, lhs, rhs).getResult()
- : arith::MulIOp::create(rewriter, loc, lhs, rhs).getResult();
- case vector::CombiningKind::MINSI:
- return arith::MinSIOp::create(rewriter, loc, lhs, rhs).getResult();
- case vector::CombiningKind::MINUI:
- return arith::MinUIOp::create(rewriter, loc, lhs, rhs).getResult();
- case vector::CombiningKind::MAXSI:
- return arith::MaxSIOp::create(rewriter, loc, lhs, rhs).getResult();
- case vector::CombiningKind::MAXUI:
- return arith::MaxUIOp::create(rewriter, loc, lhs, rhs).getResult();
- case vector::CombiningKind::AND:
- return arith::AndIOp::create(rewriter, loc, lhs, rhs).getResult();
- case vector::CombiningKind::OR:
- return arith::OrIOp::create(rewriter, loc, lhs, rhs).getResult();
- case vector::CombiningKind::XOR:
- return arith::XOrIOp::create(rewriter, loc, lhs, rhs).getResult();
- case vector::CombiningKind::MINNUMF:
- return arith::MinNumFOp::create(rewriter, loc, lhs, rhs).getResult();
- case vector::CombiningKind::MAXNUMF:
- return arith::MaxNumFOp::create(rewriter, loc, lhs, rhs).getResult();
- case vector::CombiningKind::MINIMUMF:
- return arith::MinimumFOp::create(rewriter, loc, lhs, rhs).getResult();
- case vector::CombiningKind::MAXIMUMF:
- return arith::MaximumFOp::create(rewriter, loc, lhs, rhs).getResult();
- }
- llvm_unreachable("unsupported OpKind");
-}
-
/// This pattern transforms vector.multi_dim_reduction operations from
/// workgroup-level to subgroup-level execution with support for multiple
/// reduction dimensions.
@@ -1282,7 +1242,7 @@ static Value reductionOpKind(ConversionPatternRewriter &rewriter, Location loc,
///
/// 2. CROSS-SUBGROUP :
/// - Determines if cross-subgroup reduction is needed (when sg_layout > 1 in
-/// reduction dims)
+/// reduction dims & sgData[reduction dims] < wgData[reduction dims])
/// - If not needed, adds original accumulator and returns local results
///
/// 3. SHARED LOCAL MEMORY (SLM) PHASE (when cross-subgroup reduction needed):
@@ -1357,7 +1317,11 @@ struct WgToSgMultiDimReductionOp
bool needsCrossSubgroupReduction = false;
SmallVector<int64_t> crossSgReductionDims;
for (int64_t reductionDim : reductionDims) {
- if (sgLayout[reductionDim] > 1) {
+ bool needsCrossSg =
+ (sgLayout[reductionDim] > 1) &&
+ (sgData[reductionDim] < originalSrcShape[reductionDim]);
+
+ if (needsCrossSg) {
needsCrossSubgroupReduction = true;
crossSgReductionDims.push_back(reductionDim);
}
@@ -1367,8 +1331,8 @@ struct WgToSgMultiDimReductionOp
if (!needsCrossSubgroupReduction) {
SmallVector<Value> results;
for (auto localResult : localReductions) {
- auto finalResult = reductionOpKind(rewriter, loc, op.getKind(),
- localResult, adaptor.getAcc()[0]);
+ auto finalResult = vector::makeArithReduction(
+ rewriter, loc, op.getKind(), localResult, adaptor.getAcc()[0]);
if (auto defOp = finalResult.getDefiningOp())
xegpu::setDistributeLayoutAttr(defOp->getResult(0),
layout.dropSgLayoutAndData());
@@ -1507,8 +1471,8 @@ struct WgToSgMultiDimReductionOp
}
}
- auto finalResult = reductionOpKind(rewriter, loc, op.getKind(),
- finalReduce.getResult(), accToAdd);
+ auto finalResult = vector::makeArithReduction(
+ rewriter, loc, op.getKind(), finalReduce.getResult(), accToAdd);
if (auto defOp = finalResult.getDefiningOp())
xegpu::setDistributeLayoutAttr(defOp->getResult(0),
>From 139f416c38e1019a6d773b299d17d5d1debf8a3a Mon Sep 17 00:00:00 2001
From: nbpatel <nishant.b.patel at intel.com>
Date: Mon, 12 Jan 2026 19:58:29 +0000
Subject: [PATCH 5/6] add empty layout for store_matrix
---
mlir/lib/Dialect/XeGPU/IR/XeGPUDialect.cpp | 17 +--
.../Transforms/XeGPUWgToSgDistribute.cpp | 117 +++++++++++++++---
mlir/test/Dialect/XeGPU/invalid.mlir | 19 ---
.../XeGPU/xegpu-wg-to-sg-unify-ops.mlir | 6 +-
4 files changed, 110 insertions(+), 49 deletions(-)
diff --git a/mlir/lib/Dialect/XeGPU/IR/XeGPUDialect.cpp b/mlir/lib/Dialect/XeGPU/IR/XeGPUDialect.cpp
index ccf17da26c942..3a5cd5e8df0b7 100644
--- a/mlir/lib/Dialect/XeGPU/IR/XeGPUDialect.cpp
+++ b/mlir/lib/Dialect/XeGPU/IR/XeGPUDialect.cpp
@@ -213,13 +213,9 @@ LayoutAttr::verify(llvm::function_ref<mlir::InFlightDiagnostic()> emitError,
DenseI32ArrayAttr inst_data, DenseI32ArrayAttr lane_layout,
DenseI32ArrayAttr lane_data, DenseI32ArrayAttr order) {
- // A valid layout must include at least one of sg_layout and lane_layout.
- // sg_layout is essential for Workgroup layout, while lane_layout is
- // required for Subgroup layout.
- if (!sg_layout && !inst_data && !lane_layout) {
- return emitError()
- << "expected at least one of sg_layout, inst_data or lane_layout";
- }
+ // Special case for store_matrix
+ if (!sg_layout && !inst_data && !lane_layout)
+ return success();
// generate code to check sg_laout, inst_data and lane_layout having the same
// rank if they are not null.
@@ -478,15 +474,14 @@ DistributeLayoutAttr LayoutAttr::setUnitDimLayout(SetVector<int64_t> unitDims) {
LogicalResult
SliceAttr::verify(llvm::function_ref<InFlightDiagnostic()> emitError,
xegpu::DistributeLayoutAttr parent, DenseI64ArrayAttr dims) {
- if (!parent || !dims)
- return emitError() << "expected parent layout and dims attribute";
- int64_t rank = parent.getRank();
+ if (!dims)
+ return emitError() << "expected dims attribute";
// check every element in dims is unique and smaller than rank
llvm::SmallDenseSet<int64_t> seen;
for (int64_t dim : dims.asArrayRef()) {
- if (dim < 0 || dim >= rank)
+ if (dim < 0)
return emitError() << "invalid dim (" << dim << ") in slice attribute.";
if (!seen.insert(dim).second)
return emitError() << "repeated dim (" << dim << ") in slice attribute.";
diff --git a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUWgToSgDistribute.cpp b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUWgToSgDistribute.cpp
index a1b4b5e502f3a..e60ddd8159649 100644
--- a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUWgToSgDistribute.cpp
+++ b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUWgToSgDistribute.cpp
@@ -1184,6 +1184,89 @@ struct WgToSgVectorShapeCastOp
}
};
+/// Creates a splat constant of `type` holding the identity element of the
+/// reduction `kind`, i.e. the value `e` for which `combine(e, x) == x`. Using
+/// it as the initial accumulator of a partial reduction leaves the reduced
+/// value unchanged.
+///
+/// Returns a null Value when the element type cannot carry the identity:
+///  - MINSI / MINUI / MAXSI require an integer element type,
+///  - AND requires an integer or index element type,
+///  - the float min/max kinds require a float element type.
+/// Callers must check the result before using it as an operand.
+static Value createNeutralAccumulator(ConversionPatternRewriter &rewriter,
+                                      Location loc, VectorType type,
+                                      vector::CombiningKind kind) {
+  Type elemTy = type.getElementType();
+  auto intTy = dyn_cast<IntegerType>(elemTy);
+  auto floatTy = dyn_cast<FloatType>(elemTy);
+
+  // Materializes `value` as a splat arith.constant of `type`.
+  auto splat = [&](Attribute value) -> Value {
+    return arith::ConstantOp::create(rewriter, loc, type,
+                                     DenseElementsAttr::get(type, value));
+  };
+
+  // Every CombiningKind enumerator is listed explicitly and there is no
+  // `default:` label, so the compiler reports any enumerator added later and
+  // the switch does not trip -Wcovered-switch-default.
+  switch (kind) {
+  // x + 0 == x, x ^ 0 == x, x | 0 == x, umax(x, 0) == x.
+  case vector::CombiningKind::ADD:
+  case vector::CombiningKind::XOR:
+  case vector::CombiningKind::OR:
+  case vector::CombiningKind::MAXUI:
+    return splat(rewriter.getZeroAttr(elemTy));
+
+  // x * 1 == x.
+  case vector::CombiningKind::MUL:
+    return splat(rewriter.getOneAttr(elemTy));
+
+  case vector::CombiningKind::AND: {
+    // The identity of bitwise AND is all-ones; the constant 1 would clear
+    // every bit except the lowest one.
+    if (!elemTy.isIntOrIndex())
+      return nullptr;
+    unsigned width =
+        intTy ? intTy.getWidth() : IndexType::kInternalStorageBitWidth;
+    return splat(rewriter.getIntegerAttr(elemTy, APInt::getAllOnes(width)));
+  }
+
+  case vector::CombiningKind::MINSI:
+    // Largest signed value: smin(x, INT_MAX) == x.
+    if (!intTy)
+      return nullptr;
+    return splat(rewriter.getIntegerAttr(
+        elemTy, APInt::getSignedMaxValue(intTy.getWidth())));
+
+  case vector::CombiningKind::MINUI:
+    // Largest unsigned value: umin(x, UINT_MAX) == x.
+    if (!intTy)
+      return nullptr;
+    return splat(
+        rewriter.getIntegerAttr(elemTy, APInt::getMaxValue(intTy.getWidth())));
+
+  case vector::CombiningKind::MAXSI:
+    // Smallest signed value: smax(x, INT_MIN) == x.
+    if (!intTy)
+      return nullptr;
+    return splat(rewriter.getIntegerAttr(
+        elemTy, APInt::getSignedMinValue(intTy.getWidth())));
+
+  case vector::CombiningKind::MINNUMF:
+  case vector::CombiningKind::MINIMUMF:
+    // +infinity is the seed for float min reductions.
+    if (!floatTy)
+      return nullptr;
+    return splat(rewriter.getFloatAttr(
+        elemTy, APFloat::getInf(floatTy.getFloatSemantics())));
+
+  case vector::CombiningKind::MAXNUMF:
+  case vector::CombiningKind::MAXIMUMF:
+    // -infinity is the seed for float max reductions.
+    if (!floatTy)
+      return nullptr;
+    return splat(rewriter.getFloatAttr(
+        elemTy,
+        APFloat::getInf(floatTy.getFloatSemantics(), /*Negative=*/true)));
+  }
+
+  // Only reachable if `kind` holds a value outside the enumeration.
+  return nullptr;
+}
+
/// This function converts multi-dimensional subgroup indices into a single
/// linear offset. It's used to calculate memory offsets in SLM for
/// cross-subgroup reduction coordination.
@@ -1303,32 +1386,29 @@ struct WgToSgMultiDimReductionOp
VectorType newDstType = VectorType::get(sgShape, elemTy);
for (auto sgSrc : adaptor.getSource()) {
// Create ZERO accumulator for local reduction
- auto zeroLocalAcc = arith::ConstantOp::create(
- rewriter, loc, newDstType,
- DenseElementsAttr::get(newDstType, rewriter.getZeroAttr(elemTy)));
+ auto neutralLocalAcc =
+ createNeutralAccumulator(rewriter, loc, newDstType, op.getKind());
// Local reduction with ZERO accumulator
auto localReduce = vector::MultiDimReductionOp::create(
- rewriter, loc, newDstType, op.getKind(), sgSrc,
- zeroLocalAcc.getResult(), reductionDims);
+ rewriter, loc, newDstType, op.getKind(), sgSrc, neutralLocalAcc,
+ reductionDims);
localReductions.push_back(localReduce.getResult());
}
// Check if cross-subgroup reduction is needed for any reduction dimension
- bool needsCrossSubgroupReduction = false;
SmallVector<int64_t> crossSgReductionDims;
for (int64_t reductionDim : reductionDims) {
- bool needsCrossSg =
+ bool needsCrossSubgroupReduction =
(sgLayout[reductionDim] > 1) &&
(sgData[reductionDim] < originalSrcShape[reductionDim]);
- if (needsCrossSg) {
- needsCrossSubgroupReduction = true;
+ if (needsCrossSubgroupReduction) {
crossSgReductionDims.push_back(reductionDim);
}
}
// If no cross-subgroup reduction needed, add accumulator and return
- if (!needsCrossSubgroupReduction) {
+ if (crossSgReductionDims.empty()) {
SmallVector<Value> results;
for (auto localResult : localReductions) {
auto finalResult = vector::makeArithReduction(
@@ -1419,8 +1499,15 @@ struct WgToSgMultiDimReductionOp
SmallVector<OpFoldResult> storeOffsets2D = {rowOffsetStore, colOffset};
+ auto storeMatrixLayout = xegpu::SliceAttr::get(
+ rewriter.getContext(),
+ xegpu::LayoutAttr::get(rewriter.getContext(), /*sg_layout =*/nullptr,
+ /*sg_data =*/nullptr,
+ /*inst_data =*/nullptr, /*lane_layout =*/nullptr,
+ /*lane_data =*/nullptr, /*order =*/nullptr),
+ dyn_cast<xegpu::SliceAttr>(layout).getDims());
xegpu::StoreMatrixOp::create(rewriter, loc, storeData, memDesc.getResult(),
- storeOffsets2D, /*layout=*/nullptr);
+ storeOffsets2D, /*layout=*/storeMatrixLayout);
gpu::BarrierOp::create(rewriter, loc);
@@ -1443,14 +1530,12 @@ struct WgToSgMultiDimReductionOp
SmallVector<int64_t> finalResultShape = {localElements};
VectorType finalResultType = VectorType::get(finalResultShape, elemTy);
- // Create ZERO accumulator for final reduction
- auto zeroFinalAcc = arith::ConstantOp::create(
- rewriter, loc, finalResultType,
- DenseElementsAttr::get(finalResultType, rewriter.getZeroAttr(elemTy)));
+ auto neutralFinalAcc =
+ createNeutralAccumulator(rewriter, loc, finalResultType, op.getKind());
auto finalReduce = vector::MultiDimReductionOp::create(
rewriter, loc, finalResultType, op.getKind(), loadOp.getResult(),
- zeroFinalAcc.getResult(), finalReductionDims);
+ neutralFinalAcc, finalReductionDims);
// Step 7: Add the original accumulator at the end
Value originalAcc = adaptor.getAcc()[0];
diff --git a/mlir/test/Dialect/XeGPU/invalid.mlir b/mlir/test/Dialect/XeGPU/invalid.mlir
index 67faa60f2835e..1d4953a5134d9 100644
--- a/mlir/test/Dialect/XeGPU/invalid.mlir
+++ b/mlir/test/Dialect/XeGPU/invalid.mlir
@@ -695,16 +695,6 @@ func.func @convert_layout_unmatch(%a: vector<32x64xf16>) {
gpu.return
}
-// -----
-func.func @tensor_desc_invalid_layout_attr(%src: ui64, %offsets: vector<16xindex>) {
- %1 = xegpu.create_tdesc %src, %offsets : ui64, vector<16xindex> ->
- !xegpu.tensor_desc<16x2xf32,
- #xegpu.scatter_tdesc_attr<chunk_size = 2>,
- // expected-error at +1 {{expected at least one of sg_layout, inst_data or lane_layout}}
- #xegpu.layout<sg_data = [16, 2], lane_data = [1, 2]>>
- return
-}
-
// -----
func.func @tensor_desc_rank_mismatch(%src: ui64, %offsets: vector<16xindex>) {
%1 = xegpu.create_tdesc %src, %offsets : ui64, vector<16xindex> ->
@@ -824,15 +814,6 @@ func.func @slice_attr_repeat_dim() {
return
}
-// -----
-#l = #xegpu.layout<sg_layout = [16, 1, 1], sg_data = [1, 8, 2]>
-// expected-error at +1 {{invalid dim (3) in slice attribute}}
-#s = #xegpu.slice<#l, dims = [3]>
-func.func @slice_attr_repeat_dim() {
- %offsets = arith.constant {layout_result_0 = #s} dense<0.8> : vector<16x8xindex>
- return
-}
-
// -----
func.func @create_mem_desc_non_slm() {
%m = memref.alloca() {alignment = 1024} : memref<2048xi8, 1>
diff --git a/mlir/test/Dialect/XeGPU/xegpu-wg-to-sg-unify-ops.mlir b/mlir/test/Dialect/XeGPU/xegpu-wg-to-sg-unify-ops.mlir
index bfb29e182c2bb..2c601d5069ff0 100644
--- a/mlir/test/Dialect/XeGPU/xegpu-wg-to-sg-unify-ops.mlir
+++ b/mlir/test/Dialect/XeGPU/xegpu-wg-to-sg-unify-ops.mlir
@@ -674,7 +674,7 @@ gpu.module @test_distribution {
// CHECK-DAG: %[[MUL3:.*]] = arith.muli %[[AFFINE3]], %[[C1:.*]] : index
// CHECK-DAG: %[[ADD2:.*]] = arith.addi %[[ADD1]], %[[MUL3]] : index
// CHECK-DAG: %[[COL_OFFSET:.*]] = arith.muli %[[ADD2]], %[[C32:.*]] : index
- // CHECK-DAG: xegpu.store_matrix %[[SHAPE_CAST]], %[[MEM_DESC]][%[[ROW_OFFSET]], %[[COL_OFFSET]]] : vector<1x32xf32>, !xegpu.mem_desc<32x32xf32>, index, index
+ // CHECK-DAG: xegpu.store_matrix %[[SHAPE_CAST]], %[[MEM_DESC]][%[[ROW_OFFSET]], %[[COL_OFFSET]]] <{layout = #xegpu.slice<#xegpu.layout<>, dims = [1]>}>: vector<1x32xf32>, !xegpu.mem_desc<32x32xf32>, index, index
// CHECK-DAG: gpu.barrier
// CHECK-DAG: %[[LOAD_SLM:.*]] = xegpu.load_matrix %[[MEM_DESC]][%[[C0:.*]], %[[COL_OFFSET]]] : !xegpu.mem_desc<32x32xf32>, index, index -> vector<32x32xf32>
// CHECK-DAG: %[[CST_3:.*]] = arith.constant dense<0.000000e+00> : vector<32xf32>
@@ -717,7 +717,7 @@ gpu.module @test_distribution {
// CHECK-DAG: %[[MUL4:.*]] = arith.muli {{.*}}, %[[C1:.*]] : index
// CHECK-DAG: %[[ADD1:.*]] = arith.addi %[[C0:.*]], %[[MUL4]] : index
// CHECK-DAG: %[[COL_OFFSET:.*]] = arith.muli %[[ADD1]], %[[C32:.*]] : index
- // CHECK-DAG: xegpu.store_matrix %[[SHAPE_CAST]], %[[MEM_DESC]][%[[ROW_OFFSET]], %[[COL_OFFSET]]] : vector<1x32xf32>, !xegpu.mem_desc<8x128xf32>, index, index
+ // CHECK-DAG: xegpu.store_matrix %[[SHAPE_CAST]], %[[MEM_DESC]][%[[ROW_OFFSET]], %[[COL_OFFSET]]] <{layout = #xegpu.slice<#xegpu.layout<>, dims = [0]>}>: vector<1x32xf32>, !xegpu.mem_desc<8x128xf32>, index, index
// CHECK-DAG: gpu.barrier
// CHECK-DAG: %[[LOAD_SLM:.*]] = xegpu.load_matrix %[[MEM_DESC]][%[[C0:.*]], %[[COL_OFFSET]]] : !xegpu.mem_desc<8x128xf32>, index, index -> vector<8x32xf32>
// CHECK-DAG: %[[CST_CROSS_SG_1:.*]] = arith.constant dense<0.000000e+00> : vector<32xf32>
@@ -766,7 +766,7 @@ gpu.module @test_distribution {
// CHECK-DAG: %[[MUL4:.*]] = arith.muli {{.*}}, %[[C2:.*]] : index
// CHECK-DAG: %[[ADD3:.*]] = arith.addi %[[ADD2]], %[[MUL4]] : index
// CHECK-DAG: %[[COL_OFFSET:.*]] = arith.muli %[[ADD3]], %[[C1:.*]] : index
- // CHECK-DAG: xegpu.store_matrix %[[SHAPE_CAST]], %[[MEM_DESC]][%[[ROW_OFFSET]], %[[COL_OFFSET]]] : vector<1x1xf32>, !xegpu.mem_desc<16x4xf32>, index, index
+ // CHECK-DAG: xegpu.store_matrix %[[SHAPE_CAST]], %[[MEM_DESC]][%[[ROW_OFFSET]], %[[COL_OFFSET]]] <{layout = #xegpu.slice<#xegpu.layout<>, dims = [2, 3]>}>: vector<1x1xf32>, !xegpu.mem_desc<16x4xf32>, index, index
// CHECK-DAG: gpu.barrier
// CHECK-DAG: %[[LOAD_SLM:.*]] = xegpu.load_matrix %[[MEM_DESC]][%[[C0:.*]], %[[COL_OFFSET]]] : !xegpu.mem_desc<16x4xf32>, index, index -> vector<16x1xf32>
// CHECK-DAG: %[[CST_3:.*]] = arith.constant dense<0.000000e+00> : vector<1xf32>
>From bc1ee0b34f0fc84838626f14c2a82759c8b96eba Mon Sep 17 00:00:00 2001
From: nbpatel <nishant.b.patel at intel.com>
Date: Tue, 13 Jan 2026 19:40:23 +0000
Subject: [PATCH 6/6] Fix CI
---
mlir/lib/Dialect/XeGPU/Transforms/XeGPUWgToSgDistribute.cpp | 1 -
1 file changed, 1 deletion(-)
diff --git a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUWgToSgDistribute.cpp b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUWgToSgDistribute.cpp
index e60ddd8159649..6c45d18f4bbef 100644
--- a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUWgToSgDistribute.cpp
+++ b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUWgToSgDistribute.cpp
@@ -1193,7 +1193,6 @@ static Value createNeutralAccumulator(ConversionPatternRewriter &rewriter,
case vector::CombiningKind::ADD:
case vector::CombiningKind::XOR:
case vector::CombiningKind::OR:
- case vector::CombiningKind::MAXUI:
return arith::ConstantOp::create(
rewriter, loc, type,
DenseElementsAttr::get(type, rewriter.getZeroAttr(elemTy)));
More information about the Mlir-commits
mailing list