[Mlir-commits] [mlir] [MLIR][XeGPU] Add support for reducing to scalar in sg to wi pass (PR #190193)
Nishant Patel
llvmlistbot at llvm.org
Thu Apr 2 08:23:35 PDT 2026
https://github.com/nbpatel created https://github.com/llvm/llvm-project/pull/190193
None
>From e7bc139a1a305091a2d331b54e64bc4f2d68c088 Mon Sep 17 00:00:00 2001
From: nbpatel <nishant.b.patel at intel.com>
Date: Wed, 1 Apr 2026 20:54:22 +0000
Subject: [PATCH] Add support for reducing to scalar in sg to wi pass
---
.../XeGPUSgToWiDistributeExperimental.cpp | 19 ++++++++++-
.../XeGPU/sg-to-wi-experimental-unit.mlir | 32 +++++++++++++++++++
2 files changed, 50 insertions(+), 1 deletion(-)
diff --git a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUSgToWiDistributeExperimental.cpp b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUSgToWiDistributeExperimental.cpp
index 66e1f84906294..b2f51ed912dad 100644
--- a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUSgToWiDistributeExperimental.cpp
+++ b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUSgToWiDistributeExperimental.cpp
@@ -98,6 +98,9 @@ static bool isValidSubgroupMultiReductionOp(vector::MultiDimReductionOp op) {
// If no layout, not valid.
if (!resLayout || !resLayout.isForSubgroup())
return false;
+ // Scalar result (e.g., vector<32xf32> to f32) is valid.
+ if (op.getType().isIntOrFloat())
+ return op.getReductionDims().size() == 1;
VectorType resTy = dyn_cast<VectorType>(op.getType());
if (!resTy)
return false;
@@ -600,7 +603,21 @@ struct SgToWiMultiDimReduction
op, "only unit leading dimensions are supported for "
"multi_reduction with rank > 2");
}
- if (isReductionLaneLocal(op)) {
+ // Handle scalar result: full reduction of a distributed vector to a
+ // scalar. First do a local vector reduction, then cross-lane shuffles.
+ if (op.getType().isIntOrFloat()) {
+ auto reductionDim = reductionDims[0];
+ VectorType origSourceType = op.getSourceVectorType();
+ int64_t reductionDimSize = origSourceType.getShape()[reductionDim];
+ // Local reduction to scalar, then cross-lane butterfly shuffles.
+ result =
+ xegpu::subgroupReduction(op.getLoc(), rewriter, adaptor.getSource(),
+ op.getKind(), reductionDimSize);
+ // Combine with accumulator if present.
+ if (adaptor.getAcc())
+ result = vector::makeArithReduction(rewriter, op.getLoc(), op.getKind(),
+ result, adaptor.getAcc());
+ } else if (isReductionLaneLocal(op)) {
auto resLayout = xegpu::getTemporaryLayout(op->getOpResult(0));
VectorType resVecTy = dyn_cast<VectorType>(op.getType());
auto resDistVecTyOrFailure =
diff --git a/mlir/test/Dialect/XeGPU/sg-to-wi-experimental-unit.mlir b/mlir/test/Dialect/XeGPU/sg-to-wi-experimental-unit.mlir
index 842c2375dd31d..0335105ebe7f0 100644
--- a/mlir/test/Dialect/XeGPU/sg-to-wi-experimental-unit.mlir
+++ b/mlir/test/Dialect/XeGPU/sg-to-wi-experimental-unit.mlir
@@ -1125,3 +1125,35 @@ gpu.func @vector_broadcast_scalar_to_vector_uniform(%laneid: index) {
gpu.return
}
}
+
+// -----
+gpu.module @xevm_module {  // Full 1-D reduction of vector<32xf32> to a scalar f32.
+// CHECK-LABEL: gpu.func @vector_multi_reduction_1d_to_scalar
+// CHECK: %[[SRC:.*]] = "some_op"() {{.*}} : () -> vector<32xf32>
+// CHECK: %[[DIST:.*]] = builtin.unrealized_conversion_cast %[[SRC]] : vector<32xf32> to vector<2xf32>
+// CHECK: %[[ACC:.*]] = arith.constant 0.000000e+00 : f32
+// CHECK: %[[LANE_RED:.*]] = vector.reduction <add>, %[[DIST]] : vector<2xf32> into f32
+// CHECK: %[[SHFL1:.*]], %{{.*}} = gpu.shuffle xor %[[LANE_RED]], %[[C1:.*]], %[[C32:.*]] : f32
+// CHECK: %[[ADD1:.*]] = arith.addf %[[LANE_RED]], %[[SHFL1]] : f32
+// CHECK: %[[SHFL2:.*]], %{{.*}} = gpu.shuffle xor %[[ADD1]], %[[C2:.*]], %[[C32:.*]] : f32
+// CHECK: %[[ADD2:.*]] = arith.addf %[[ADD1]], %[[SHFL2]] : f32
+// CHECK: %[[SHFL3:.*]], %{{.*}} = gpu.shuffle xor %[[ADD2]], %[[C4:.*]], %[[C32:.*]] : f32
+// CHECK: %[[ADD3:.*]] = arith.addf %[[ADD2]], %[[SHFL3]] : f32
+// CHECK: %[[SHFL4:.*]], %{{.*}} = gpu.shuffle xor %[[ADD3]], %[[C8:.*]], %[[C32:.*]] : f32
+// CHECK: %[[ADD4:.*]] = arith.addf %[[ADD3]], %[[SHFL4]] : f32
+// CHECK: %[[SHFL5:.*]], %{{.*}} = gpu.shuffle xor %[[ADD4]], %[[C16:.*]], %[[C32:.*]] : f32
+// CHECK: %[[ADD5:.*]] = arith.addf %[[ADD4]], %[[SHFL5]] : f32
+// CHECK: %[[FINAL:.*]] = arith.addf %[[ADD5]], %[[ACC]] : f32
+gpu.func @vector_multi_reduction_1d_to_scalar() {
+ %src = "some_op"()
+ {layout_result_0 = #xegpu.layout<lane_layout = [16], lane_data = [1]>}  // 16 lanes, 1 element each: 32 elements -> vector<2xf32> per lane.
+ : () -> vector<32xf32>
+ %acc = arith.constant 0.0 : f32  // Accumulator; expected to be added once, after the cross-lane shuffles.
+ %1 = vector.multi_reduction <add>, %src, %acc  // NOTE(review): expected output uses 5 xor stages (1..16) with width 32, but lane_layout has 16 lanes, where log2(16) = 4 stages of width 16 would suffice - confirm the size passed to subgroupReduction.
+ {
+ layout_result_0 = #xegpu.slice<#xegpu.layout<lane_layout = [16], lane_data = [1]>, dims = [0]>  // dims = [0]: dim 0 is reduced away, leaving a scalar.
+ }
+ [0] : vector<32xf32> to f32  // Reduce over dim 0, the only dim of the source.
+ gpu.return
+}
+}
More information about the Mlir-commits mailing list