[Mlir-commits] [mlir] [MLIR][XeGPU] Support leading unit dims in vector.multi_reduction in sg to wi pass (PR #188767)
Nishant Patel
llvmlistbot at llvm.org
Thu Mar 26 08:23:11 PDT 2026
https://github.com/nbpatel created https://github.com/llvm/llvm-project/pull/188767
(No description provided.)
>From 6af4262cab5a33784b0de4bdf813d5b33fc1e629 Mon Sep 17 00:00:00 2001
From: nbpatel <nishant.b.patel at intel.com>
Date: Thu, 26 Mar 2026 03:47:47 +0000
Subject: [PATCH] Support leading unit dims in reduction
---
.../XeGPUSgToWiDistributeExperimental.cpp | 11 +++
mlir/lib/Dialect/XeGPU/Utils/XeGPUUtils.cpp | 50 +++++++----
.../XeGPU/sg-to-wi-experimental-unit.mlir | 87 +++++++++++++++++++
3 files changed, 130 insertions(+), 18 deletions(-)
diff --git a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUSgToWiDistributeExperimental.cpp b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUSgToWiDistributeExperimental.cpp
index 0961ddfb92040..0b19ddf6163a0 100644
--- a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUSgToWiDistributeExperimental.cpp
+++ b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUSgToWiDistributeExperimental.cpp
@@ -588,6 +588,17 @@ struct SgToWiMultiDimReduction
assert(reductionDims.size() == 1 &&
"Expecting single reduction dimension for subgroup multi "
"reduction op");
+ // For rank > 2, ensure leading dimensions are unit.
+ VectorType sourceType = op.getSourceVectorType();
+ int64_t rank = sourceType.getRank();
+ if (rank > 2) {
+ ArrayRef<int64_t> shape = sourceType.getShape();
+ if (llvm::any_of(shape.take_front(rank - 2),
+ [](int64_t d) { return d != 1; }))
+ return rewriter.notifyMatchFailure(
+ op, "only unit leading dimensions are supported for "
+ "multi_reduction with rank > 2");
+ }
if (isReductionLaneLocal(op)) {
auto resLayout = xegpu::getTemporaryLayout(op->getOpResult(0));
VectorType resVecTy = dyn_cast<VectorType>(op.getType());
diff --git a/mlir/lib/Dialect/XeGPU/Utils/XeGPUUtils.cpp b/mlir/lib/Dialect/XeGPU/Utils/XeGPUUtils.cpp
index f60635830cc74..930c9e898dec4 100644
--- a/mlir/lib/Dialect/XeGPU/Utils/XeGPUUtils.cpp
+++ b/mlir/lib/Dialect/XeGPU/Utils/XeGPUUtils.cpp
@@ -750,11 +750,18 @@ Value xegpu::lowerCrossLaneReductionToShuffles(
TypedValue<VectorType> src, TypedValue<VectorType> acc,
vector::CombiningKind kind, int64_t reductionDim, int64_t reductionSize,
Location loc, PatternRewriter &rewriter) {
- // Expecting a 2D source vector.
- assert(src.getType().getRank() == 2 && "expected a 2D source vector");
VectorType sourceType = src.getType();
- int64_t sourceH = sourceType.getShape()[0];
- int64_t sourceW = sourceType.getShape()[1];
+ int64_t sourceRank = sourceType.getRank();
+ // Expecting at least a 2D source vector. Leading dimensions (all except the
+ // last two) must be unit.
+ assert(sourceRank >= 2 && "expected at least a 2D source vector");
+ for (int64_t i = 0; i < sourceRank - 2; ++i)
+ assert(sourceType.getShape()[i] == 1 &&
+ "expected leading dimensions to be unit");
+ int64_t rowIdx = sourceRank - 2;
+ int64_t columnIdx = sourceRank - 1;
+ int64_t sourceH = sourceType.getShape()[rowIdx];
+ int64_t sourceW = sourceType.getShape()[columnIdx];
// Create a constant vector to hold the result of the reduction.
TypedAttr zeroAttr = rewriter.getZeroAttr(sourceType.getElementType());
@@ -763,39 +770,46 @@ Value xegpu::lowerCrossLaneReductionToShuffles(
DenseElementsAttr::get(acc.getType(), zeroAttr));
// nSlices is the number of reduction operations needed to reduce the entire
- // source vector. For example, if reductionDim is 0, we are reducing across
- // rows, and each slice is a column of the source vector. So the number of
- // slices is the number of columns, which is sourceW.
- int nSlices = (reductionDim == 0) ? sourceW : sourceH;
+ // source vector. For example, if reductionDim is the row dim, we are
+ // reducing across rows, and each slice is a column. So the number of slices
+ // is the number of columns, which is sourceW.
+ int nSlices = (reductionDim == rowIdx) ? sourceW : sourceH;
// For each slice of the source, extract the slice vector, do a reduction
// and, insert the reduced value back to the result vector.
+ int64_t accRank = acc.getType().getRank();
for (int i = 0; i < nSlices; ++i) {
- SmallVector<int64_t, 2> sliceOffsets, sliceSizes;
- if (reductionDim == 1) {
- sliceOffsets = {i, 0};
- sliceSizes = {1, sourceW};
+ // Build nD offsets, sizes, and strides. Leading unit dims get
+ // offset=0, size=1. The last two dims are set based on reductionDim.
+ SmallVector<int64_t> sliceOffsets(sourceRank, 0);
+ SmallVector<int64_t> sliceSizes(sourceRank, 1);
+ SmallVector<int64_t> strides(sourceRank, 1);
+ if (reductionDim == columnIdx) {
+ sliceOffsets[rowIdx] = i;
+ sliceSizes[columnIdx] = sourceW;
} else {
- sliceOffsets = {0, i};
- sliceSizes = {sourceH, 1};
+ sliceOffsets[columnIdx] = i;
+ sliceSizes[rowIdx] = sourceH;
}
vector::ExtractStridedSliceOp extractOp =
vector::ExtractStridedSliceOp::create(rewriter, loc, src, sliceOffsets,
- sliceSizes, {1, 1});
+ sliceSizes, strides);
int64_t nSliceElements = extractOp.getResult().getType().getNumElements();
vector::ShapeCastOp slice = vector::ShapeCastOp::create(
rewriter, loc,
VectorType::get({nSliceElements}, sourceType.getElementType()),
extractOp.getResult());
- Value accExtract = vector::ExtractOp::create(rewriter, loc, acc, i);
+ SmallVector<int64_t> accIdx(accRank, 0);
+ accIdx[accRank - 1] = i;
+ Value accExtract = vector::ExtractOp::create(rewriter, loc, acc, accIdx);
Value fullReduce =
xegpu::subgroupReduction(loc, rewriter, slice, kind, reductionSize);
fullReduce =
vector::makeArithReduction(rewriter, loc, kind, fullReduce, accExtract);
- reductionResult =
- vector::InsertOp::create(rewriter, loc, fullReduce, reductionResult, i);
+ reductionResult = vector::InsertOp::create(rewriter, loc, fullReduce,
+ reductionResult, accIdx);
}
return reductionResult;
}
diff --git a/mlir/test/Dialect/XeGPU/sg-to-wi-experimental-unit.mlir b/mlir/test/Dialect/XeGPU/sg-to-wi-experimental-unit.mlir
index 016b393e3d8bc..07c63f5933a9a 100644
--- a/mlir/test/Dialect/XeGPU/sg-to-wi-experimental-unit.mlir
+++ b/mlir/test/Dialect/XeGPU/sg-to-wi-experimental-unit.mlir
@@ -461,6 +461,93 @@ gpu.func @vector_multi_reduction_dim0_distributed_dim1_reduction(%laneid: index)
gpu.return
}
+// CHECK-LABEL: gpu.func @vector_multi_reduction_3d_leading_unit_dim_lane_local
+// CHECK: %[[CST:.*]] = arith.constant dense<0.000000e+00> : vector<1x16x2xf32>
+// CHECK: %[[CST_0:.*]] = arith.constant dense<0.000000e+00> : vector<1x2xf32>
+// CHECK: %[[RED:.*]] = vector.multi_reduction <add>, %[[CST]], %[[CST_0]] [1] : vector<1x16x2xf32> to vector<1x2xf32>
+// CHECK: gpu.return
+gpu.func @vector_multi_reduction_3d_leading_unit_dim_lane_local() {
+ %src = arith.constant
+ {layout_result_0 = #xegpu.layout<lane_layout = [1, 1, 16], lane_data = [1, 1, 1]>}
+ dense<0.0> : vector<1x16x32xf32>
+ %acc = arith.constant
+ {layout_result_0 = #xegpu.slice<#xegpu.layout<lane_layout = [1, 1, 16], lane_data = [1, 1, 1]>, dims = [1]>}
+ dense<0.0> : vector<1x32xf32>
+ %1 = vector.multi_reduction <add>, %src, %acc
+ {
+ layout_result_0 = #xegpu.slice<#xegpu.layout<lane_layout = [1, 1, 16], lane_data = [1, 1, 1]>, dims = [1]>
+ }
+ [1] : vector<1x16x32xf32> to vector<1x32xf32>
+ gpu.return
+}
+
+// Cross-lane 3D reduction with leading unit dim.
+// Source distributed on dim 1, reducing dim 1 => cross-lane shuffle.
+// CHECK-LABEL: gpu.func @vector_multi_reduction_3d_leading_unit_dim_cross_lane
+// CHECK: %[[SRC:.*]] = arith.constant dense<0.000000e+00> : vector<1x1x2xf32>
+// CHECK: %[[ACC:.*]] = arith.constant dense<0.000000e+00> : vector<1x2xf32>
+// CHECK: %[[SLICE0:.*]] = vector.extract_strided_slice %[[SRC]]
+// CHECK-SAME: {offsets = [0, 0, 0], sizes = [1, 1, 1], strides = [1, 1, 1]}
+// CHECK: %[[FLAT0:.*]] = vector.shape_cast %[[SLICE0]] : vector<1x1x1xf32> to vector<1xf32>
+// CHECK: %[[ACC0:.*]] = vector.extract %[[ACC]][0, 0] : f32 from vector<1x2xf32>
+// CHECK: %[[RED0:.*]] = vector.reduction <add>, %[[FLAT0]] : vector<1xf32> into f32
+// CHECK: %[[C16_0:.*]] = arith.constant 16 : i32
+// CHECK: %[[C1:.*]] = arith.constant 1 : i32
+// CHECK: %[[SHUF0_1:.*]], %{{.*}} = gpu.shuffle xor %[[RED0]], %[[C1]], %[[C16_0]] : f32
+// CHECK: %[[ADD0_1:.*]] = arith.addf %[[RED0]], %[[SHUF0_1]] : f32
+// CHECK: %[[C16_1:.*]] = arith.constant 16 : i32
+// CHECK: %[[C2:.*]] = arith.constant 2 : i32
+// CHECK: %[[SHUF0_2:.*]], %{{.*}} = gpu.shuffle xor %[[ADD0_1]], %[[C2]], %[[C16_1]] : f32
+// CHECK: %[[ADD0_2:.*]] = arith.addf %[[ADD0_1]], %[[SHUF0_2]] : f32
+// CHECK: %[[C16_2:.*]] = arith.constant 16 : i32
+// CHECK: %[[C4:.*]] = arith.constant 4 : i32
+// CHECK: %[[SHUF0_4:.*]], %{{.*}} = gpu.shuffle xor %[[ADD0_2]], %[[C4]], %[[C16_2]] : f32
+// CHECK: %[[ADD0_4:.*]] = arith.addf %[[ADD0_2]], %[[SHUF0_4]] : f32
+// CHECK: %[[C16_3:.*]] = arith.constant 16 : i32
+// CHECK: %[[C8:.*]] = arith.constant 8 : i32
+// CHECK: %[[SHUF0_8:.*]], %{{.*}} = gpu.shuffle xor %[[ADD0_4]], %[[C8]], %[[C16_3]] : f32
+// CHECK: %[[ADD0_8:.*]] = arith.addf %[[ADD0_4]], %[[SHUF0_8]] : f32
+// CHECK: %[[FINAL0:.*]] = arith.addf %[[ADD0_8]], %[[ACC0]] : f32
+// CHECK: %[[INS0:.*]] = vector.insert %[[FINAL0]], %{{.*}} [0, 0] : f32 into vector<1x2xf32>
+// CHECK: %[[SLICE1:.*]] = vector.extract_strided_slice %[[SRC]]
+// CHECK-SAME: {offsets = [0, 0, 1], sizes = [1, 1, 1], strides = [1, 1, 1]}
+// CHECK: %[[FLAT1:.*]] = vector.shape_cast %[[SLICE1]] : vector<1x1x1xf32> to vector<1xf32>
+// CHECK: %[[ACC1:.*]] = vector.extract %[[ACC]][0, 1] : f32 from vector<1x2xf32>
+// CHECK: %[[RED1:.*]] = vector.reduction <add>, %[[FLAT1]] : vector<1xf32> into f32
+// CHECK: %[[C16_4:.*]] = arith.constant 16 : i32
+// CHECK: %[[C1_1:.*]] = arith.constant 1 : i32
+// CHECK: %[[SHUF1_1:.*]], %{{.*}} = gpu.shuffle xor %[[RED1]], %[[C1_1]], %[[C16_4]] : f32
+// CHECK: %[[ADD1_1:.*]] = arith.addf %[[RED1]], %[[SHUF1_1]] : f32
+// CHECK: %[[C16_5:.*]] = arith.constant 16 : i32
+// CHECK: %[[C2_1:.*]] = arith.constant 2 : i32
+// CHECK: %[[SHUF1_2:.*]], %{{.*}} = gpu.shuffle xor %[[ADD1_1]], %[[C2_1]], %[[C16_5]] : f32
+// CHECK: %[[ADD1_2:.*]] = arith.addf %[[ADD1_1]], %[[SHUF1_2]] : f32
+// CHECK: %[[C16_6:.*]] = arith.constant 16 : i32
+// CHECK: %[[C4_1:.*]] = arith.constant 4 : i32
+// CHECK: %[[SHUF1_4:.*]], %{{.*}} = gpu.shuffle xor %[[ADD1_2]], %[[C4_1]], %[[C16_6]] : f32
+// CHECK: %[[ADD1_4:.*]] = arith.addf %[[ADD1_2]], %[[SHUF1_4]] : f32
+// CHECK: %[[C16_7:.*]] = arith.constant 16 : i32
+// CHECK: %[[C8_1:.*]] = arith.constant 8 : i32
+// CHECK: %[[SHUF1_8:.*]], %{{.*}} = gpu.shuffle xor %[[ADD1_4]], %[[C8_1]], %[[C16_7]] : f32
+// CHECK: %[[ADD1_8:.*]] = arith.addf %[[ADD1_4]], %[[SHUF1_8]] : f32
+// CHECK: %[[FINAL1:.*]] = arith.addf %[[ADD1_8]], %[[ACC1]] : f32
+// CHECK: vector.insert %[[FINAL1]], %[[INS0]] [0, 1] : f32 into vector<1x2xf32>
+// CHECK: gpu.return
+gpu.func @vector_multi_reduction_3d_leading_unit_dim_cross_lane() {
+ %src = arith.constant
+ {layout_result_0 = #xegpu.layout<lane_layout = [1, 16, 1], lane_data = [1, 1, 1]>}
+ dense<0.0> : vector<1x16x2xf32>
+ %acc = arith.constant
+ {layout_result_0 = #xegpu.slice<#xegpu.layout<lane_layout = [1, 16, 1], lane_data = [1, 1, 1]>, dims = [1]>}
+ dense<0.0> : vector<1x2xf32>
+ %1 = vector.multi_reduction <add>, %src, %acc
+ {
+ layout_result_0 = #xegpu.slice<#xegpu.layout<lane_layout = [1, 16, 1], lane_data = [1, 1, 1]>, dims = [1]>
+ }
+ [1] : vector<1x16x2xf32> to vector<1x2xf32>
+ gpu.return
+}
+
// CHECK-LABEL: gpu.func @vector_extract_from_2d
// CHECK: %[[EXT:.*]] = vector.extract %{{.*}}[0] : vector<1xf32> from vector<4x1xf32>
gpu.func @vector_extract_from_2d() {
More information about the Mlir-commits mailing list