[Mlir-commits] [mlir] [MLIR][XeGPU] Lower vector.multi_reduction to vector.reduction for lane local (PR #191037)
Nishant Patel
llvmlistbot at llvm.org
Wed Apr 8 12:19:31 PDT 2026
https://github.com/nbpatel created https://github.com/llvm/llvm-project/pull/191037
None
>From 65f18bf269c4328bf7711359a278750af56b5a2c Mon Sep 17 00:00:00 2001
From: nbpatel <nishant.b.patel at intel.com>
Date: Wed, 8 Apr 2026 17:02:15 +0000
Subject: [PATCH] Lower multi_reduction to reduction for lane local
---
.../mlir/Dialect/XeGPU/Utils/XeGPUUtils.h | 3 +-
.../XeGPUSgToWiDistributeExperimental.cpp | 20 ++++++-----
mlir/lib/Dialect/XeGPU/Utils/XeGPUUtils.cpp | 30 +++++++++-------
.../XeGPU/sg-to-wi-experimental-unit.mlir | 35 ++++++++++++++-----
.../Dialect/XeGPU/sg-to-wi-experimental.mlir | 3 +-
5 files changed, 58 insertions(+), 33 deletions(-)
diff --git a/mlir/include/mlir/Dialect/XeGPU/Utils/XeGPUUtils.h b/mlir/include/mlir/Dialect/XeGPU/Utils/XeGPUUtils.h
index 0aa2cd45088f3..f571aece0daf7 100644
--- a/mlir/include/mlir/Dialect/XeGPU/Utils/XeGPUUtils.h
+++ b/mlir/include/mlir/Dialect/XeGPU/Utils/XeGPUUtils.h
@@ -145,7 +145,8 @@ Value subgroupReduction(Location loc, OpBuilder &builder, Value input,
Value lowerToVectorReductions(TypedValue<VectorType> src,
TypedValue<VectorType> acc,
vector::CombiningKind kind, int64_t reductionDim,
- Location loc, PatternRewriter &rewriter);
+ Location loc, PatternRewriter &rewriter,
+ bool setLayout = true);
/// Creates a constant filled with the neutral (identity) value for the
/// given reduction kind. For example: 0 for ADD/OR/XOR, 1 for MUL/AND,
diff --git a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUSgToWiDistributeExperimental.cpp b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUSgToWiDistributeExperimental.cpp
index e3227c7f5b149..31ed21f8be143 100644
--- a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUSgToWiDistributeExperimental.cpp
+++ b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUSgToWiDistributeExperimental.cpp
@@ -590,15 +590,17 @@ struct SgToWiMultiDimReduction
result = vector::makeArithReduction(rewriter, op.getLoc(), op.getKind(),
result, adaptor.getAcc());
} else if (isReductionLaneLocal(op)) {
- auto resLayout = xegpu::getTemporaryLayout(op->getOpResult(0));
- VectorType resVecTy = dyn_cast<VectorType>(op.getType());
- auto resDistVecTyOrFailure =
- getDistVecTypeBasedOnLaneLayout(resLayout, resVecTy);
- // For lane local reduction, simply create a new MultiDimReductionOp using
- // adaptor operands and the new result type.
- result = vector::MultiDimReductionOp::create(
- rewriter, op.getLoc(), resDistVecTyOrFailure.value(), op.getKind(),
- adaptor.getSource(), adaptor.getAcc(), op.getReductionDims());
+ // For lane-local reduction, lower to a sequence of vector.reduction ops
+ // over 1D slices extracted from the distributed source vector, so that no
+ // 2D source vectors remain when xegpu-linearize runs. setLayout=false skips
+ // attaching temporary layouts to the generated ops; the parameter exists
+ // only so lowerToVectorReductions can serve both the old and the new pass,
+ // and will be removed once the old pass is deprecated.
+ auto reductionDim = reductionDims[0];
+ result = xegpu::lowerToVectorReductions(
+ cast<TypedValue<VectorType>>(adaptor.getSource()),
+ cast<TypedValue<VectorType>>(adaptor.getAcc()), op.getKind(),
+ reductionDim, op.getLoc(), rewriter, /*setLayout=*/false);
} else {
auto reductionDim = reductionDims[0];
VectorType sourceType = op.getSourceVectorType();
diff --git a/mlir/lib/Dialect/XeGPU/Utils/XeGPUUtils.cpp b/mlir/lib/Dialect/XeGPU/Utils/XeGPUUtils.cpp
index 243581b4ce522..acae9bf1c9562 100644
--- a/mlir/lib/Dialect/XeGPU/Utils/XeGPUUtils.cpp
+++ b/mlir/lib/Dialect/XeGPU/Utils/XeGPUUtils.cpp
@@ -671,7 +671,8 @@ Value xegpu::lowerToVectorReductions(TypedValue<VectorType> src,
TypedValue<VectorType> acc,
vector::CombiningKind kind,
int64_t reductionDim, Location loc,
- PatternRewriter &rewriter) {
+ PatternRewriter &rewriter,
+ bool setLayout) {
VectorType sourceType = src.getType();
int64_t sourceRank = sourceType.getRank();
// Expecting at least a 2D source vector. Leading dimensions (all except the
@@ -690,10 +691,13 @@ Value xegpu::lowerToVectorReductions(TypedValue<VectorType> src,
Value reductionResult = arith::ConstantOp::create(
rewriter, loc, acc.getType(),
DenseElementsAttr::get(acc.getType(), zeroAttr));
- auto srcLayout = xegpu::getTemporaryLayout(dyn_cast<OpResult>(src));
- auto accLayout = xegpu::getTemporaryLayout(dyn_cast<OpResult>(acc));
- // Reduction result should have the same layout as the accumulator.
- xegpu::setTemporaryLayout(cast<OpResult>(reductionResult), accLayout);
+ xegpu::DistributeLayoutAttr srcLayout, accLayout;
+ if (setLayout) {
+ srcLayout = xegpu::getTemporaryLayout(dyn_cast<OpResult>(src));
+ accLayout = xegpu::getTemporaryLayout(dyn_cast<OpResult>(acc));
+ // Reduction result should have the same layout as the accumulator.
+ xegpu::setTemporaryLayout(cast<OpResult>(reductionResult), accLayout);
+ }
// For each slice of the source, extract the slice vector, do a reduction
// and, insert the reduced value back to the result vector.
int64_t accRank = acc.getType().getRank();
@@ -714,8 +718,8 @@ Value xegpu::lowerToVectorReductions(TypedValue<VectorType> src,
vector::ExtractStridedSliceOp extractOp =
vector::ExtractStridedSliceOp::create(rewriter, loc, src, sliceOffsets,
sliceSizes, strides);
- // Extract strided slice has the same layout as src.
- xegpu::setTemporaryLayout(extractOp->getOpResult(0), srcLayout);
+ if (setLayout)
+ xegpu::setTemporaryLayout(extractOp->getOpResult(0), srcLayout);
int64_t nSliceElements = extractOp.getResult().getType().getNumElements();
@@ -724,10 +728,10 @@ Value xegpu::lowerToVectorReductions(TypedValue<VectorType> src,
VectorType::get({nSliceElements}, sourceType.getElementType()),
extractOp.getResult());
- // Shape cast output has the same layout as the accumulator. Shape cast
- // source has the same layout as the original reduction source.
- xegpu::setTemporaryLayout(slice->getOpOperand(0), srcLayout);
- xegpu::setTemporaryLayout(slice->getOpResult(0), accLayout);
+ if (setLayout) {
+ xegpu::setTemporaryLayout(slice->getOpOperand(0), srcLayout);
+ xegpu::setTemporaryLayout(slice->getOpResult(0), accLayout);
+ }
// Extract and reduction results in scalars, so no result layout is needed.
// Build multi-dim index into acc (sourceRank-1 dims, i.e. source shape with
// the reduction dim removed). Leading unit dims get index 0.
@@ -738,8 +742,8 @@ Value xegpu::lowerToVectorReductions(TypedValue<VectorType> src,
rewriter, loc, kind, slice.getResult(), accExtract);
reductionResult = vector::InsertOp::create(rewriter, loc, reduction,
reductionResult, accIdx);
- // Insert op should have the same layout as the accumulator.
- xegpu::setTemporaryLayout(cast<OpResult>(reductionResult), accLayout);
+ if (setLayout)
+ xegpu::setTemporaryLayout(cast<OpResult>(reductionResult), accLayout);
}
return reductionResult;
}
diff --git a/mlir/test/Dialect/XeGPU/sg-to-wi-experimental-unit.mlir b/mlir/test/Dialect/XeGPU/sg-to-wi-experimental-unit.mlir
index 0335105ebe7f0..4c3727388831b 100644
--- a/mlir/test/Dialect/XeGPU/sg-to-wi-experimental-unit.mlir
+++ b/mlir/test/Dialect/XeGPU/sg-to-wi-experimental-unit.mlir
@@ -432,9 +432,13 @@ gpu.func @vector_multi_reduction_dim0_distributed_dim0_reduction(%laneid: index)
}
// CHECK-LABEL: gpu.func @vector_multi_reduction_dim1_distributed_dim0_reduction
-// CHECK: %[[CST:.*]] = arith.constant dense<0.000000e+00> : vector<4x1xf32>
-// CHECK: %[[CST_0:.*]] = arith.constant dense<0.000000e+00> : vector<1xf32>
-// CHECK: %[[RED:.*]] = vector.multi_reduction <add>, %[[CST]], %[[CST_0]] [0] : vector<4x1xf32> to vector<1xf32>
+// CHECK-DAG: %[[SRC:.*]] = arith.constant dense<0.000000e+00> : vector<4x1xf32>
+// CHECK-DAG: %[[ACC:.*]] = arith.constant dense<0.000000e+00> : vector<1xf32>
+// CHECK: %[[SLICE:.*]] = vector.extract_strided_slice %[[SRC]] {offsets = [0, 0], sizes = [4, 1], strides = [1, 1]} : vector<4x1xf32> to vector<4x1xf32>
+// CHECK: %[[FLAT:.*]] = vector.shape_cast %[[SLICE]] : vector<4x1xf32> to vector<4xf32>
+// CHECK: %[[ACC_EL:.*]] = vector.extract %[[ACC]][0] : f32 from vector<1xf32>
+// CHECK: %[[RED:.*]] = vector.reduction <add>, %[[FLAT]], %[[ACC_EL]] : vector<4xf32> into f32
+// CHECK: vector.insert %[[RED]], %{{.*}} [0] : f32 into vector<1xf32>
// CHECK: gpu.return
gpu.func @vector_multi_reduction_dim1_distributed_dim0_reduction(%laneid: index) {
%c0 = arith.constant 0 : index
@@ -453,9 +457,13 @@ gpu.func @vector_multi_reduction_dim1_distributed_dim0_reduction(%laneid: index)
}
// CHECK-LABEL: gpu.func @vector_multi_reduction_dim0_distributed_dim1_reduction
-// CHECK: %[[CST:.*]] = arith.constant dense<0.000000e+00> : vector<1x12xf32>
-// CHECK: %[[CST_0:.*]] = arith.constant dense<0.000000e+00> : vector<1xf32>
-// CHECK: %[[RED:.*]] = vector.multi_reduction <add>, %[[CST]], %[[CST_0]] [1] : vector<1x12xf32> to vector<1xf32>
+// CHECK-DAG: %[[SRC:.*]] = arith.constant dense<0.000000e+00> : vector<1x12xf32>
+// CHECK-DAG: %[[ACC:.*]] = arith.constant dense<0.000000e+00> : vector<1xf32>
+// CHECK: %[[SLICE:.*]] = vector.extract_strided_slice %[[SRC]] {offsets = [0, 0], sizes = [1, 12], strides = [1, 1]} : vector<1x12xf32> to vector<1x12xf32>
+// CHECK: %[[FLAT:.*]] = vector.shape_cast %[[SLICE]] : vector<1x12xf32> to vector<12xf32>
+// CHECK: %[[ACC_EL:.*]] = vector.extract %[[ACC]][0] : f32 from vector<1xf32>
+// CHECK: %[[RED:.*]] = vector.reduction <add>, %[[FLAT]], %[[ACC_EL]] : vector<12xf32> into f32
+// CHECK: vector.insert %[[RED]], %{{.*}} [0] : f32 into vector<1xf32>
// CHECK: gpu.return
gpu.func @vector_multi_reduction_dim0_distributed_dim1_reduction(%laneid: index) {
%c0 = arith.constant 0 : index
@@ -582,9 +590,18 @@ gpu.func @constant_mask_2d() {
// CHECK-LABEL: gpu.func @vector_multi_reduction_3d_leading_unit_dim_lane_local
-// CHECK: %[[CST:.*]] = arith.constant dense<0.000000e+00> : vector<1x16x2xf32>
-// CHECK: %[[CST_0:.*]] = arith.constant dense<0.000000e+00> : vector<1x2xf32>
-// CHECK: %[[RED:.*]] = vector.multi_reduction <add>, %[[CST]], %[[CST_0]] [1] : vector<1x16x2xf32> to vector<1x2xf32>
+// CHECK-DAG: %[[SRC:.*]] = arith.constant dense<0.000000e+00> : vector<1x16x2xf32>
+// CHECK-DAG: %[[ACC:.*]] = arith.constant dense<0.000000e+00> : vector<1x2xf32>
+// CHECK: %[[S0:.*]] = vector.extract_strided_slice %[[SRC]] {offsets = [0, 0, 0], sizes = [1, 16, 1], strides = [1, 1, 1]} : vector<1x16x2xf32> to vector<1x16x1xf32>
+// CHECK: %[[F0:.*]] = vector.shape_cast %[[S0]] : vector<1x16x1xf32> to vector<16xf32>
+// CHECK: %[[A0:.*]] = vector.extract %[[ACC]][0, 0] : f32 from vector<1x2xf32>
+// CHECK: %[[R0:.*]] = vector.reduction <add>, %[[F0]], %[[A0]] : vector<16xf32> into f32
+// CHECK: %[[I0:.*]] = vector.insert %[[R0]], %{{.*}} [0, 0] : f32 into vector<1x2xf32>
+// CHECK: %[[S1:.*]] = vector.extract_strided_slice %[[SRC]] {offsets = [0, 0, 1], sizes = [1, 16, 1], strides = [1, 1, 1]} : vector<1x16x2xf32> to vector<1x16x1xf32>
+// CHECK: %[[F1:.*]] = vector.shape_cast %[[S1]] : vector<1x16x1xf32> to vector<16xf32>
+// CHECK: %[[A1:.*]] = vector.extract %[[ACC]][0, 1] : f32 from vector<1x2xf32>
+// CHECK: %[[R1:.*]] = vector.reduction <add>, %[[F1]], %[[A1]] : vector<16xf32> into f32
+// CHECK: vector.insert %[[R1]], %[[I0]] [0, 1] : f32 into vector<1x2xf32>
// CHECK: gpu.return
gpu.func @vector_multi_reduction_3d_leading_unit_dim_lane_local() {
%src = arith.constant
diff --git a/mlir/test/Dialect/XeGPU/sg-to-wi-experimental.mlir b/mlir/test/Dialect/XeGPU/sg-to-wi-experimental.mlir
index 9febd79c7adc3..babb01c131792 100644
--- a/mlir/test/Dialect/XeGPU/sg-to-wi-experimental.mlir
+++ b/mlir/test/Dialect/XeGPU/sg-to-wi-experimental.mlir
@@ -445,7 +445,8 @@ gpu.module @xevm_module{
// -----
// CHECK-LABEL: gpu.func @vector_broadcast_1d_to_2d_broadcast_within_lane({{.*}}) {
-// CHECK: %[[BCAST:.*]] = vector.broadcast %{{.*}} : vector<1xf16> to vector<16xf16>
+// CHECK: %[[RED:.*]] = vector.reduction <add>, %{{.*}}, %{{.*}} : vector<16xf16> into f16
+// CHECK: %[[BCAST:.*]] = vector.broadcast %[[RED]] : f16 to vector<16xf16>
gpu.module @xevm_module{
gpu.func @vector_broadcast_1d_to_2d_broadcast_within_lane(%arg0: memref<16x16xf16>, %arg1: memref<16x16xf16>) {
%c0 = arith.constant 0 : index
More information about the Mlir-commits
mailing list