[Mlir-commits] [mlir] [MLIR][XeGPU] Support leading unit dim for reduction in sg to wi pass (PR #185110)
Nishant Patel
llvmlistbot at llvm.org
Fri Mar 6 13:49:18 PST 2026
https://github.com/nbpatel updated https://github.com/llvm/llvm-project/pull/185110
>From cf58c338c18eef7bf9056cf179177462be44e582 Mon Sep 17 00:00:00 2001
From: nbpatel <nishant.b.patel at intel.com>
Date: Fri, 6 Mar 2026 19:29:32 +0000
Subject: [PATCH 1/2] Support leading unit dim
---
.../Transforms/XeGPUSubgroupDistribute.cpp | 44 +++++++----
mlir/lib/Dialect/XeGPU/Utils/XeGPUUtils.cpp | 46 ++++++++----
.../XeGPU/subgroup-distribute-unit.mlir | 74 +++++++++++++++++++
3 files changed, 135 insertions(+), 29 deletions(-)
diff --git a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUSubgroupDistribute.cpp b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUSubgroupDistribute.cpp
index 38bc95d39c2c6..2a889ffe35896 100644
--- a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUSubgroupDistribute.cpp
+++ b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUSubgroupDistribute.cpp
@@ -1321,17 +1321,31 @@ struct VectorMultiReductionDistribution : public gpu::WarpDistributionPattern {
cast<vector::MultiDimReductionOp>(yieldOperand->get().getDefiningOp());
unsigned operandIdx = yieldOperand->getOperandNumber();
VectorType sourceType = reductionOp.getSourceVectorType();
- // Only 2D vectors are supported.
- if (sourceType.getRank() != 2)
+ int64_t sourceRank = sourceType.getRank();
+ // Need at least a 2D source vector.
+ if (sourceRank < 2)
return rewriter.notifyMatchFailure(warpOp,
- "Only 2D reductions are supported.");
+ "Only 2D+ reductions are supported.");
+ // Leading dimensions (first rank-2) must be unit (size 1).
+ for (int64_t i = 0; i < sourceRank - 2; ++i) {
+ if (sourceType.getShape()[i] != 1)
+ return rewriter.notifyMatchFailure(
+ warpOp, "Only unit dimensions allowed for the leading dimensions.");
+ }
+ // Effective dimension indices (last 2 dims of the source).
+ int64_t dim0Idx = sourceRank - 2;
+ int64_t dim1Idx = sourceRank - 1;
ArrayRef<int64_t> reductionDims = reductionOp.getReductionDims();
- // Only 1 reduction dimension supported. This also ensures that the result
- // is vector type.
- if (reductionDims.size() != 1)
+ // Find effective reduction dims among the last 2 dims (skip leading dims).
+ SmallVector<int64_t> effectiveReductionDims;
+ for (int64_t d : reductionDims) {
+ if (d == dim0Idx || d == dim1Idx)
+ effectiveReductionDims.push_back(d);
+ }
+ if (effectiveReductionDims.size() != 1)
return rewriter.notifyMatchFailure(
- warpOp, "Only 1 reduction dimension is supported.");
- int64_t reductionDim = reductionDims[0];
+ warpOp, "Only 1 non-trivial effective reduction dim is supported.");
+ int64_t reductionDim = effectiveReductionDims[0];
VectorType distributedResultType =
cast<VectorType>(warpOp.getResult(operandIdx).getType());
VectorType resultType = cast<VectorType>(reductionOp.getType());
@@ -1344,15 +1358,16 @@ struct VectorMultiReductionDistribution : public gpu::WarpDistributionPattern {
return rewriter.notifyMatchFailure(
warpOp, "Failed to distribute the source vector type.");
VectorType sourceDistType = sourceDistTypeOrFailure.value();
- // Only single dimension distribution is supported.
+ // Only single dimension distribution among the last 2 dims is supported.
bool dim0Distributed =
- sourceDistType.getShape()[0] != sourceType.getShape()[0];
+ sourceDistType.getShape()[dim0Idx] != sourceType.getShape()[dim0Idx];
bool dim1Distributed =
- sourceDistType.getShape()[1] != sourceType.getShape()[1];
+ sourceDistType.getShape()[dim1Idx] != sourceType.getShape()[dim1Idx];
if (dim0Distributed && dim1Distributed)
return rewriter.notifyMatchFailure(
warpOp, "Expecting source to be distributed in a single dimension.");
- int64_t sourceDistDim = dim0Distributed ? 0 : (dim1Distributed ? 1 : -1);
+ int64_t sourceDistDim =
+ dim0Distributed ? dim0Idx : (dim1Distributed ? dim1Idx : -1);
if (sourceDistDim == -1)
return rewriter.notifyMatchFailure(
warpOp, "Expecting a distributed source vector.");
@@ -1371,8 +1386,9 @@ struct VectorMultiReductionDistribution : public gpu::WarpDistributionPattern {
// | dim-1 distributed | 0 | distributed |
// | dim-1 distributed | 1 | broadcasted |
- bool isReductionLaneLocal = (sourceDistDim == 0 && reductionDim == 1) ||
- (sourceDistDim == 1 && reductionDim == 0);
+ bool isReductionLaneLocal =
+ (sourceDistDim == dim0Idx && reductionDim == dim1Idx) ||
+ (sourceDistDim == dim1Idx && reductionDim == dim0Idx);
if (isReductionLaneLocal && !resultDistributed)
return rewriter.notifyMatchFailure(
warpOp, "Expecting a distributed result for lane-local reduction.");
diff --git a/mlir/lib/Dialect/XeGPU/Utils/XeGPUUtils.cpp b/mlir/lib/Dialect/XeGPU/Utils/XeGPUUtils.cpp
index 3271e73e0b571..c31eda2aa2a60 100644
--- a/mlir/lib/Dialect/XeGPU/Utils/XeGPUUtils.cpp
+++ b/mlir/lib/Dialect/XeGPU/Utils/XeGPUUtils.cpp
@@ -671,12 +671,19 @@ Value xegpu::lowerToVectorReductions(TypedValue<VectorType> src,
vector::CombiningKind kind,
int64_t reductionDim, Location loc,
PatternRewriter &rewriter) {
- // Expecting a 2D source vector.
- assert(src.getType().getRank() == 2 && "expected a 2D source vector");
VectorType sourceType = src.getType();
- int64_t sourceH = sourceType.getShape()[0];
- int64_t sourceW = sourceType.getShape()[1];
- int nSlices = (reductionDim == 0) ? sourceW : sourceH;
+ int64_t sourceRank = sourceType.getRank();
+ // Expecting at least a 2D source vector. Leading dimensions (all except the
+ // last two) must be unit.
+ assert(sourceRank >= 2 && "expected at least a 2D source vector");
+ for (int64_t i = 0; i < sourceRank - 2; ++i)
+ assert(sourceType.getShape()[i] == 1 &&
+ "expected leading dimensions to be unit");
+ int64_t dim0Idx = sourceRank - 2;
+ int64_t dim1Idx = sourceRank - 1;
+ int64_t sourceH = sourceType.getShape()[dim0Idx];
+ int64_t sourceW = sourceType.getShape()[dim1Idx];
+ int nSlices = (reductionDim == dim0Idx) ? sourceW : sourceH;
// Create a constant vector to hold the result of the reduction.
TypedAttr zeroAttr = rewriter.getZeroAttr(sourceType.getElementType());
Value reductionResult = arith::ConstantOp::create(
@@ -688,19 +695,24 @@ Value xegpu::lowerToVectorReductions(TypedValue<VectorType> src,
xegpu::setTemporaryLayout(cast<OpResult>(reductionResult), accLayout);
// For each slice of the source, extract the slice vector, do a reduction
// and, insert the reduced value back to the result vector.
+ int64_t accRank = acc.getType().getRank();
for (int i = 0; i < nSlices; ++i) {
- SmallVector<int64_t, 2> sliceOffsets, sliceSizes;
- if (reductionDim == 1) {
- sliceOffsets = {i, 0};
- sliceSizes = {1, sourceW};
+ // Build nD offsets, sizes, and strides. Leading unit dims get
+ // offset=0, size=1. The last two dims are set based on reductionDim.
+ SmallVector<int64_t> sliceOffsets(sourceRank, 0);
+ SmallVector<int64_t> sliceSizes(sourceRank, 1);
+ SmallVector<int64_t> strides(sourceRank, 1);
+ if (reductionDim == dim1Idx) {
+ sliceOffsets[dim0Idx] = i;
+ sliceSizes[dim1Idx] = sourceW;
} else {
- sliceOffsets = {0, i};
- sliceSizes = {sourceH, 1};
+ sliceOffsets[dim1Idx] = i;
+ sliceSizes[dim0Idx] = sourceH;
}
vector::ExtractStridedSliceOp extractOp =
vector::ExtractStridedSliceOp::create(rewriter, loc, src, sliceOffsets,
- sliceSizes, {1, 1});
+ sliceSizes, strides);
// Extract strided slice has the same layout as src.
xegpu::setTemporaryLayout(extractOp->getOpResult(0), srcLayout);
@@ -716,11 +728,15 @@ Value xegpu::lowerToVectorReductions(TypedValue<VectorType> src,
xegpu::setTemporaryLayout(slice->getOpOperand(0), srcLayout);
xegpu::setTemporaryLayout(slice->getOpResult(0), accLayout);
// Extract and reduction results in scalars, so no result layout is needed.
- Value accExtract = vector::ExtractOp::create(rewriter, loc, acc, i);
+ // Build multi-dim index into acc (sourceRank-1 dims, i.e. source shape with
+ // the reduction dim removed). Leading unit dims get index 0.
+ SmallVector<int64_t> accIdx(accRank, 0);
+ accIdx[accRank - 1] = i;
+ Value accExtract = vector::ExtractOp::create(rewriter, loc, acc, accIdx);
Value reduction = vector::ReductionOp::create(
rewriter, loc, kind, slice.getResult(), accExtract);
- reductionResult =
- vector::InsertOp::create(rewriter, loc, reduction, reductionResult, i);
+ reductionResult = vector::InsertOp::create(rewriter, loc, reduction,
+ reductionResult, accIdx);
// Insert op should have the same layout as the accumulator.
xegpu::setTemporaryLayout(cast<OpResult>(reductionResult), accLayout);
}
diff --git a/mlir/test/Dialect/XeGPU/subgroup-distribute-unit.mlir b/mlir/test/Dialect/XeGPU/subgroup-distribute-unit.mlir
index 8b980c5083af3..fc6ccd3f3f887 100644
--- a/mlir/test/Dialect/XeGPU/subgroup-distribute-unit.mlir
+++ b/mlir/test/Dialect/XeGPU/subgroup-distribute-unit.mlir
@@ -396,6 +396,80 @@ gpu.func @vector_multi_reduction_dim0_distributed_dim0_reduction(%laneid: index)
}
+// CHECK-LABEL: gpu.func @vector_multi_reduction_3d_leading_unit_dim
+// CHECK: %[[ACC:.*]] = arith.constant {{.*}} dense<0.000000e+00> : vector<1x32xf32>
+// CHECK: %[[W:.*]]:3 = gpu.warp_execute_on_lane_0(%{{.*}})[16]
+// CHECK-SAME: -> (vector<1x2xf32>, vector<1x16x2xf32>, vector<1x2xf32>) {
+// CHECK: %[[SRC:.*]] = "some_def"() {{.*}} : () -> vector<1x16x32xf32>
+// CHECK: gpu.yield %{{.*}}, %[[SRC]], %[[ACC]] : vector<1x32xf32>, vector<1x16x32xf32>, vector<1x32xf32>
+// CHECK-NEXT: }
+// CHECK: %[[T1:.*]] = vector.extract_strided_slice %[[W]]#1
+// CHECK-SAME: {offsets = [0, 0, 0], sizes = [1, 16, 1], strides = [1, 1, 1]} : vector<1x16x2xf32> to vector<1x16x1xf32>
+// CHECK: %[[T2:.*]] = vector.shape_cast %[[T1]] : vector<1x16x1xf32> to vector<16xf32>
+// CHECK: %[[T3:.*]] = vector.extract %[[W]]#2[0, 0] : f32 from vector<1x2xf32>
+// CHECK: %[[T4:.*]] = vector.reduction <add>, %[[T2]], %[[T3]] : vector<16xf32> into f32
+// CHECK: %[[T5:.*]] = vector.extract_strided_slice %[[W]]#1
+// CHECK-SAME: {offsets = [0, 0, 1], sizes = [1, 16, 1], strides = [1, 1, 1]} : vector<1x16x2xf32> to vector<1x16x1xf32>
+// CHECK: %[[T6:.*]] = vector.shape_cast %[[T5]] : vector<1x16x1xf32> to vector<16xf32>
+// CHECK: %[[T7:.*]] = vector.extract %[[W]]#2[0, 1] : f32 from vector<1x2xf32>
+// CHECK: %[[T8:.*]] = vector.reduction <add>, %[[T6]], %[[T7]] : vector<16xf32> into f32
+// CHECK: %[[T9:.*]] = vector.from_elements %[[T4]], %[[T8]] : vector<1x2xf32>
+gpu.func @vector_multi_reduction_3d_leading_unit_dim(%laneid: index) {
+ %c0 = arith.constant 0 : index
+ %r = gpu.warp_execute_on_lane_0(%laneid)[16] -> (vector<1x2xf32>) {
+ %src = "some_def"()
+ {layout_result_0 = #xegpu.layout<lane_layout = [1, 1, 16], lane_data = [1, 1, 1]>}
+ : () -> (vector<1x16x32xf32>)
+ %acc = arith.constant
+ {layout_result_0 = #xegpu.slice<#xegpu.layout<lane_layout = [1, 1, 16], lane_data = [1, 1, 1]>, dims = [1]>}
+ dense<0.0> : vector<1x32xf32>
+ %1 = vector.multi_reduction <add>, %src, %acc
+ {
+ layout_operand_0 = #xegpu.layout<lane_layout = [1, 1, 16], lane_data = [1, 1, 1]>,
+ layout_operand_1 = #xegpu.slice<#xegpu.layout<lane_layout = [1, 1, 16], lane_data = [1, 1, 1]>, dims = [1]>,
+ layout_result_0 = #xegpu.slice<#xegpu.layout<lane_layout = [1, 1, 16], lane_data = [1, 1, 1]>, dims = [1]>
+ }
+ [1] : vector<1x16x32xf32> to vector<1x32xf32>
+ gpu.yield %1 : vector<1x32xf32>
+ }
+ "some_user_op"(%r) : (vector<1x2xf32>) -> ()
+ gpu.return
+}
+
+
+// CHECK-LABEL: gpu.func @vector_multi_reduction_3d_trivial_reduction
+// CHECK: %[[W:.*]]:3 = gpu.warp_execute_on_lane_0(%{{.*}})[16]
+// CHECK-SAME: -> (vector<1x1xf32>, vector<1x1x1xf32>, vector<1x1xf32>) {
+// CHECK: %[[SRC:.*]] = "some_def"() {{.*}} : () -> vector<1x1x16xf32>
+// CHECK: gpu.yield %{{.*}}, %[[SRC]], %{{.*}} : vector<1x16xf32>, vector<1x1x16xf32>, vector<1x16xf32>
+// CHECK-NEXT: }
+// CHECK: %[[A:.*]] = vector.extract %[[W]]#2[0, 0] : f32 from vector<1x1xf32>
+// CHECK: %[[S:.*]] = vector.extract %[[W]]#1[0, 0, 0] : f32 from vector<1x1x1xf32>
+// CHECK: %[[ADD:.*]] = arith.addf %[[S]], %[[A]] : f32
+// CHECK: %[[BC:.*]] = vector.broadcast %[[ADD]] : f32 to vector<1x1xf32>
+gpu.func @vector_multi_reduction_3d_trivial_reduction(%laneid: index) {
+ %c0 = arith.constant 0 : index
+ %r = gpu.warp_execute_on_lane_0(%laneid)[16] -> (vector<1x1xf32>) {
+ %src = "some_def"()
+ {layout_result_0 = #xegpu.layout<lane_layout = [1, 1, 16], lane_data = [1, 1, 1]>}
+ : () -> (vector<1x1x16xf32>)
+ %acc = arith.constant
+ {layout_result_0 = #xegpu.slice<#xegpu.layout<lane_layout = [1, 1, 16], lane_data = [1, 1, 1]>, dims = [1]>}
+ dense<0.0> : vector<1x16xf32>
+ %1 = vector.multi_reduction <add>, %src, %acc
+ {
+ layout_operand_0 = #xegpu.layout<lane_layout = [1, 1, 16], lane_data = [1, 1, 1]>,
+ layout_operand_1 = #xegpu.slice<#xegpu.layout<lane_layout = [1, 1, 16], lane_data = [1, 1, 1]>, dims = [1]>,
+ layout_result_0 = #xegpu.slice<#xegpu.layout<lane_layout = [1, 1, 16], lane_data = [1, 1, 1]>, dims = [1]>
+ }
+ [1] : vector<1x1x16xf32> to vector<1x16xf32>
+ gpu.yield %1 : vector<1x16xf32>
+ }
+ "some_user_op"(%r) : (vector<1x1xf32>) -> ()
+ gpu.return
+}
+
+
// CHECK-LABEL: gpu.func @scatter_ops_chunksize({{.*}}) {
// CHECK: %[[OFFSETS:.*]] = arith.constant {{.*}} dense<12> : vector<16xindex>
// CHECK: %[[MASKS:.*]] = arith.constant {{.*}} dense<true> : vector<16xi1>
>From 785e64569c6cf8d110ef7877f3667dc495a48f00 Mon Sep 17 00:00:00 2001
From: nbpatel <nishant.b.patel at intel.com>
Date: Fri, 6 Mar 2026 21:46:03 +0000
Subject: [PATCH 2/2] Clean up
---
.../XeGPU/Transforms/XeGPUSubgroupDistribute.cpp | 16 +++++++---------
1 file changed, 7 insertions(+), 9 deletions(-)
diff --git a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUSubgroupDistribute.cpp b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUSubgroupDistribute.cpp
index 2a889ffe35896..d7467308b9676 100644
--- a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUSubgroupDistribute.cpp
+++ b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUSubgroupDistribute.cpp
@@ -1336,16 +1336,14 @@ struct VectorMultiReductionDistribution : public gpu::WarpDistributionPattern {
int64_t dim0Idx = sourceRank - 2;
int64_t dim1Idx = sourceRank - 1;
ArrayRef<int64_t> reductionDims = reductionOp.getReductionDims();
- // Find effective reduction dims among the last 2 dims (skip leading dims).
- SmallVector<int64_t> effectiveReductionDims;
- for (int64_t d : reductionDims) {
- if (d == dim0Idx || d == dim1Idx)
- effectiveReductionDims.push_back(d);
- }
- if (effectiveReductionDims.size() != 1)
+ if (reductionDims.size() != 1)
+ return rewriter.notifyMatchFailure(warpOp,
+ "Only 1 reduction dim is supported.");
+ int64_t reductionDim = reductionDims[0];
+ // The reduction dim must be among the last 2 dims.
+ if (reductionDim != dim0Idx && reductionDim != dim1Idx)
return rewriter.notifyMatchFailure(
- warpOp, "Only 1 non-trivial effective reduction dim is supported.");
- int64_t reductionDim = effectiveReductionDims[0];
+ warpOp, "Reduction dim must be among the last 2 dimensions.");
VectorType distributedResultType =
cast<VectorType>(warpOp.getResult(operandIdx).getType());
VectorType resultType = cast<VectorType>(reductionOp.getType());
More information about the Mlir-commits
mailing list