[Mlir-commits] [mlir] [MLIR][XeGPU] Add support for lowering vector.multi_reduction to scalar in Wg to Sg (PR #188623)
Nishant Patel
llvmlistbot at llvm.org
Wed Mar 25 14:55:21 PDT 2026
https://github.com/nbpatel created https://github.com/llvm/llvm-project/pull/188623
None
>From cbf2305ebbe5d0791751e319f91d0e749b688bb4 Mon Sep 17 00:00:00 2001
From: nbpatel <nishant.b.patel at intel.com>
Date: Wed, 25 Mar 2026 21:34:57 +0000
Subject: [PATCH] Add support for reduction to scalar in Wg to Sg
---
.../mlir/Dialect/XeGPU/Utils/XeGPUUtils.h | 11 +--
.../Transforms/XeGPUWgToSgDistribute.cpp | 57 ++++++++------
mlir/lib/Dialect/XeGPU/Utils/XeGPUUtils.cpp | 75 +++++++------------
.../XeGPU/xegpu-wg-to-sg-unify-ops.mlir | 43 +++++++++--
4 files changed, 106 insertions(+), 80 deletions(-)
diff --git a/mlir/include/mlir/Dialect/XeGPU/Utils/XeGPUUtils.h b/mlir/include/mlir/Dialect/XeGPU/Utils/XeGPUUtils.h
index 5a806799e896f..0aa2cd45088f3 100644
--- a/mlir/include/mlir/Dialect/XeGPU/Utils/XeGPUUtils.h
+++ b/mlir/include/mlir/Dialect/XeGPU/Utils/XeGPUUtils.h
@@ -147,13 +147,14 @@ Value lowerToVectorReductions(TypedValue<VectorType> src,
vector::CombiningKind kind, int64_t reductionDim,
Location loc, PatternRewriter &rewriter);
-/// Creates a constant vector filled with the neutral (identity) value for the
+/// Creates a constant filled with the neutral (identity) value for the
/// given reduction kind. For example: 0 for ADD/OR/XOR, 1 for MUL/AND,
/// max/min signed/unsigned int for MINSI/MINUI/MAXSI/MAXUI, and +/-infinity
-/// for float min/max operations. Returns nullptr if the element type is
-/// incompatible with the requested reduction kind.
-Value createReductionNeutralValue(OpBuilder &builder, Location loc,
- VectorType type, vector::CombiningKind kind);
+/// for float min/max operations. If \p type is a VectorType, returns a splat
+/// vector constant; otherwise returns a scalar constant. Returns nullptr if
+/// the element type is incompatible with the requested reduction kind.
+Value createReductionNeutralValue(OpBuilder &builder, Location loc, Type type,
+ vector::CombiningKind kind);
/// Lowers cross-lane reductions to shuffle operations on a 2D vector.
/// Extracts slices along the reduction dimension, performs subgroup reductions
diff --git a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUWgToSgDistribute.cpp b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUWgToSgDistribute.cpp
index 6dea94c0c5de3..3d1d1ca3ecf98 100644
--- a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUWgToSgDistribute.cpp
+++ b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUWgToSgDistribute.cpp
@@ -1233,13 +1233,13 @@ struct WgToSgMultiDimReductionOp
Location loc = op.getLoc();
VectorType srcType = op.getSourceVectorType();
- VectorType dstType = dyn_cast<VectorType>(op.getResult().getType());
- if (!dstType)
- return failure();
+ Type resultTy = op.getResult().getType();
+ VectorType dstVecType = dyn_cast<VectorType>(resultTy);
+ bool isScalarResult = !dstVecType;
auto originalSrcShape = srcType.getShape();
- auto originalDstShape = dstType.getShape();
int srcVecRank = originalSrcShape.size();
+ Type elemTy = srcType.getElementType();
xegpu::DistributeLayoutAttr layout =
xegpu::getTemporaryLayout(dyn_cast<OpResult>(op.getResult()));
@@ -1258,25 +1258,33 @@ struct WgToSgMultiDimReductionOp
return rewriter.notifyMatchFailure(
op, "Reduction should have SliceAttr layout");
- Type elemTy = dstType.getElementType();
-
- // Step 1: perform local subgroup reductions with ZERO accumulator
+ // Step 1: perform local subgroup reductions with neutral accumulator
SmallVector<Value> localReductions;
- SmallVector<int64_t> sgDstShape =
- getSgShapeAndCount(originalDstShape, layout).first;
auto sgSrcs = adaptor.getSource();
auto sgSrcType = dyn_cast<VectorType>(sgSrcs.front().getType());
SmallVector<int64_t> sgSrcShape(sgSrcType.getShape().begin(),
sgSrcType.getShape().end());
- VectorType newDstType = VectorType::get(sgDstShape, elemTy);
+ // Determine the SG-level destination type.
+ // For scalar results (all dims reduced), the sg result is also scalar.
+ // For vector results, compute the sg destination shape from layout.
+ Type sgDstType;
+ if (dstVecType) {
+ auto originalDstShape = dstVecType.getShape();
+ SmallVector<int64_t> sgDstShape =
+ getSgShapeAndCount(originalDstShape, layout).first;
+ sgDstType = VectorType::get(sgDstShape, elemTy);
+ } else {
+ sgDstType = elemTy;
+ }
+
for (auto sgSrc : sgSrcs) {
- // Create ZERO accumulator for local reduction
- auto neutralLocalAcc = xegpu::createReductionNeutralValue(
- rewriter, loc, newDstType, op.getKind());
- // Local reduction with ZERO accumulator
+ // Create neutral accumulator for local reduction
+ Value neutralLocalAcc = xegpu::createReductionNeutralValue(
+ rewriter, loc, sgDstType, op.getKind());
+ // Local reduction with neutral accumulator
auto localReduce = vector::MultiDimReductionOp::create(
- rewriter, loc, newDstType, op.getKind(), sgSrc, neutralLocalAcc,
+ rewriter, loc, sgDstType, op.getKind(), sgSrc, neutralLocalAcc,
reductionDims);
localReductions.push_back(localReduce.getResult());
}
@@ -1310,8 +1318,15 @@ struct WgToSgMultiDimReductionOp
for (int64_t dim : reductionDims)
slmStoreDataShape[dim] = 1;
VectorType slmStoreDataType = VectorType::get(slmStoreDataShape, elemTy);
- Value slmStoreData = vector::ShapeCastOp::create(
- rewriter, loc, slmStoreDataType, localReductions[0]);
+ Value slmStoreData;
+ if (isScalarResult) {
+ // Scalar result: broadcast scalar to vector<1x...x1> for SLM store
+ slmStoreData = vector::BroadcastOp::create(
+ rewriter, loc, slmStoreDataType, localReductions[0]);
+ } else {
+ slmStoreData = vector::ShapeCastOp::create(
+ rewriter, loc, slmStoreDataType, localReductions[0]);
+ }
SmallVector<int64_t> slmShape(originalSrcShape.begin(),
originalSrcShape.end());
@@ -1393,12 +1408,12 @@ struct WgToSgMultiDimReductionOp
rewriter, loc, slmLoadType, memDesc.getResult(), slmLoadOffsets,
/*layout=*/nullptr);
- // Step 6: Perform final reduction with ZERO accumulator
- auto neutralFinalAcc = xegpu::createReductionNeutralValue(
- rewriter, loc, newDstType, op.getKind());
+ // Step 6: Perform final reduction with neutral accumulator
+ Value neutralFinalAcc = xegpu::createReductionNeutralValue(
+ rewriter, loc, sgDstType, op.getKind());
auto finalReduce = vector::MultiDimReductionOp::create(
- rewriter, loc, newDstType, op.getKind(), slmLoadOp.getResult(),
+ rewriter, loc, sgDstType, op.getKind(), slmLoadOp.getResult(),
neutralFinalAcc, reductionDims);
// Step 7: Add the original accumulator at the end
diff --git a/mlir/lib/Dialect/XeGPU/Utils/XeGPUUtils.cpp b/mlir/lib/Dialect/XeGPU/Utils/XeGPUUtils.cpp
index f60635830cc74..6c902f725ca0c 100644
--- a/mlir/lib/Dialect/XeGPU/Utils/XeGPUUtils.cpp
+++ b/mlir/lib/Dialect/XeGPU/Utils/XeGPUUtils.cpp
@@ -801,77 +801,60 @@ Value xegpu::lowerCrossLaneReductionToShuffles(
}
Value xegpu::createReductionNeutralValue(OpBuilder &builder, Location loc,
- VectorType type,
+ Type type,
vector::CombiningKind kind) {
- Type elemTy = type.getElementType();
+ auto vecTy = dyn_cast<VectorType>(type);
+ Type elemTy = vecTy ? vecTy.getElementType() : type;
+
+ // Helper to create either a splat vector or scalar constant from an attr.
+ auto makeConst = [&](Attribute scalarAttr) -> Value {
+ if (vecTy)
+ return arith::ConstantOp::create(
+ builder, loc, vecTy, DenseElementsAttr::get(vecTy, scalarAttr));
+ return arith::ConstantOp::create(builder, loc, cast<TypedAttr>(scalarAttr));
+ };
switch (kind) {
case vector::CombiningKind::ADD:
case vector::CombiningKind::XOR:
case vector::CombiningKind::OR:
- return arith::ConstantOp::create(
- builder, loc, type,
- DenseElementsAttr::get(type, builder.getZeroAttr(elemTy)));
+ case vector::CombiningKind::MAXUI:
+ return makeConst(builder.getZeroAttr(elemTy));
case vector::CombiningKind::MUL:
case vector::CombiningKind::AND:
- return arith::ConstantOp::create(
- builder, loc, type,
- DenseElementsAttr::get(type, builder.getOneAttr(elemTy)));
+ return makeConst(builder.getOneAttr(elemTy));
case vector::CombiningKind::MINSI:
- // Use max signed int value for signed integer min
- if (auto intTy = dyn_cast<IntegerType>(elemTy)) {
- auto maxVal = APInt::getSignedMaxValue(intTy.getWidth());
- return arith::ConstantOp::create(
- builder, loc, type,
- DenseElementsAttr::get(type, builder.getIntegerAttr(elemTy, maxVal)));
- }
+ if (auto intTy = dyn_cast<IntegerType>(elemTy))
+ return makeConst(builder.getIntegerAttr(
+ elemTy, APInt::getSignedMaxValue(intTy.getWidth())));
return nullptr;
case vector::CombiningKind::MINUI:
- if (auto intTy = dyn_cast<IntegerType>(elemTy)) {
- auto maxVal = APInt::getMaxValue(intTy.getWidth());
- return arith::ConstantOp::create(
- builder, loc, type,
- DenseElementsAttr::get(type, builder.getIntegerAttr(elemTy, maxVal)));
- }
+ if (auto intTy = dyn_cast<IntegerType>(elemTy))
+ return makeConst(
+ builder.getIntegerAttr(elemTy, APInt::getMaxValue(intTy.getWidth())));
return nullptr;
case vector::CombiningKind::MAXSI:
- if (auto intTy = dyn_cast<IntegerType>(elemTy)) {
- auto minVal = APInt::getSignedMinValue(intTy.getWidth());
- return arith::ConstantOp::create(
- builder, loc, type,
- DenseElementsAttr::get(type, builder.getIntegerAttr(elemTy, minVal)));
- }
+ if (auto intTy = dyn_cast<IntegerType>(elemTy))
+ return makeConst(builder.getIntegerAttr(
+ elemTy, APInt::getSignedMinValue(intTy.getWidth())));
return nullptr;
- case vector::CombiningKind::MAXUI:
- return arith::ConstantOp::create(
- builder, loc, type,
- DenseElementsAttr::get(type, builder.getZeroAttr(elemTy)));
-
case vector::CombiningKind::MINNUMF:
case vector::CombiningKind::MINIMUMF:
- // Use +infinity for float min operations
- if (auto floatTy = dyn_cast<FloatType>(elemTy)) {
- auto posInf = APFloat::getInf(floatTy.getFloatSemantics());
- return arith::ConstantOp::create(
- builder, loc, type,
- DenseElementsAttr::get(type, builder.getFloatAttr(elemTy, posInf)));
- }
+ if (auto floatTy = dyn_cast<FloatType>(elemTy))
+ return makeConst(builder.getFloatAttr(
+ elemTy, APFloat::getInf(floatTy.getFloatSemantics())));
return nullptr;
case vector::CombiningKind::MAXNUMF:
case vector::CombiningKind::MAXIMUMF:
- // Use -infinity for float max operations
- if (auto floatTy = dyn_cast<FloatType>(elemTy)) {
- auto negInf = APFloat::getInf(floatTy.getFloatSemantics(), true);
- return arith::ConstantOp::create(
- builder, loc, type,
- DenseElementsAttr::get(type, builder.getFloatAttr(elemTy, negInf)));
- }
+ if (auto floatTy = dyn_cast<FloatType>(elemTy))
+ return makeConst(builder.getFloatAttr(
+ elemTy, APFloat::getInf(floatTy.getFloatSemantics(), true)));
return nullptr;
}
return nullptr;
diff --git a/mlir/test/Dialect/XeGPU/xegpu-wg-to-sg-unify-ops.mlir b/mlir/test/Dialect/XeGPU/xegpu-wg-to-sg-unify-ops.mlir
index ecc5fe3dd75e0..950d9ba66f0cc 100644
--- a/mlir/test/Dialect/XeGPU/xegpu-wg-to-sg-unify-ops.mlir
+++ b/mlir/test/Dialect/XeGPU/xegpu-wg-to-sg-unify-ops.mlir
@@ -1,10 +1,10 @@
// RUN: mlir-opt --xegpu-wg-to-sg-distribute -split-input-file %s | FileCheck %s
-// CHECK-DAG: #map = affine_map<()[s0] -> (s0 floordiv 32)>
-// CHECK-DAG: #map1 = affine_map<()[s0] -> (s0 mod 32)>
-// CHECK-DAG: #map2 = affine_map<()[s0] -> (0)>
-// CHECK-DAG: #map3 = affine_map<()[s0] -> (s0 floordiv 4)>
-// CHECK-DAG: #map4 = affine_map<()[s0] -> (s0 mod 4)>
+// CHECK-DAG: #map = affine_map<()[s0] -> (s0 floordiv 4)>
+// CHECK-DAG: #map1 = affine_map<()[s0] -> (s0 mod 4)>
+// CHECK-DAG: #map2 = affine_map<()[s0] -> (s0 floordiv 32)>
+// CHECK-DAG: #map3 = affine_map<()[s0] -> (s0 mod 32)>
+// CHECK-DAG: #map4 = affine_map<()[s0] -> (0)>
// CHECK-DAG: #map5 = affine_map<()[s0] -> ((s0 mod 32) floordiv 16)>
// CHECK-DAG: #map6 = affine_map<()[s0] -> (s0 mod 16)>
// CHECK-DAG: #map7 = affine_map<()[s0] -> ((s0 mod 16) floordiv 4)>
@@ -412,6 +412,33 @@ gpu.module @test_distribution {
gpu.return
}
+ // CHECK-LABEL: gpu.func @vector_reduce_scalar_cross_sg
+ // CHECK-SAME: (%[[ARG0:.*]]: memref<32x32xf32>)
+ // CHECK-DAG: %[[CST:.*]] = arith.constant {{.*}} 0.000000e+00 : f32
+ // CHECK-DAG: %[[LOAD:.*]] = xegpu.load_nd %{{.*}} : !xegpu.tensor_desc<8x8xf32> -> vector<8x8xf32>
+ // CHECK-DAG: %[[CST_ACC:.*]] = arith.constant 0.000000e+00 : f32
+ // CHECK-DAG: %[[LOCAL:.*]] = vector.multi_reduction <add>, %[[LOAD]], %[[CST_ACC]] [0, 1] : vector<8x8xf32> to f32
+ // CHECK-DAG: %[[BCAST:.*]] = vector.broadcast %[[LOCAL]] : f32 to vector<1x1xf32>
+ // CHECK-DAG: %[[ALLOCA:.*]] = memref.alloca() : memref<64xi8, 3>
+ // CHECK-DAG: %[[MEM_DESC:.*]] = xegpu.create_mem_desc %[[ALLOCA]] : memref<64xi8, 3> -> !xegpu.mem_desc<4x4xf32>
+ // CHECK-DAG: xegpu.store_matrix %[[BCAST]], %[[MEM_DESC]]{{.*}} : vector<1x1xf32>, !xegpu.mem_desc<4x4xf32>
+ // CHECK-DAG: gpu.barrier
+ // CHECK-DAG: %[[LOAD_SLM:.*]] = xegpu.load_matrix %[[MEM_DESC]]{{.*}} -> vector<4x4xf32>
+ // CHECK-DAG: %[[CST_FINAL:.*]] = arith.constant 0.000000e+00 : f32
+ // CHECK-DAG: %[[FINAL:.*]] = vector.multi_reduction <add>, %[[LOAD_SLM]], %[[CST_FINAL]] [0, 1] : vector<4x4xf32> to f32
+ // CHECK-DAG: arith.addf %[[FINAL]], %[[CST]] : f32
+ gpu.func @vector_reduce_scalar_cross_sg(%src: memref<32x32xf32>) {
+ %cst = arith.constant {layout_result_0 = #xegpu.slice<#xegpu.layout<sg_layout = [4, 4], sg_data = [8, 8]>, dims = [0, 1]>} 0.0 : f32
+ %tdesc = xegpu.create_nd_tdesc %src[0, 0] : memref<32x32xf32>
+ -> !xegpu.tensor_desc<32x32xf32, #xegpu.layout<sg_layout = [4, 4], sg_data = [8, 8]>>
+ %load = xegpu.load_nd %tdesc[0, 0] {layout = #xegpu.layout<sg_layout = [4, 4], sg_data = [8, 8]>}
+ : !xegpu.tensor_desc<32x32xf32, #xegpu.layout<sg_layout = [4, 4], sg_data = [8, 8]>>
+ -> vector<32x32xf32>
+ %reduce = vector.multi_reduction <add>, %load, %cst {layout_result_0 = #xegpu.slice<#xegpu.layout<sg_layout = [4, 4], sg_data = [8, 8]>, dims = [0, 1]>} [0, 1]
+ : vector<32x32xf32> to f32
+ gpu.return
+ }
+
// CHECK-LABEL: vector_step_op
gpu.func @vector_step_op_slice_attr() {
//CHECK: [[sgId:%.+]] = gpu.subgroup_id : index
@@ -654,9 +681,9 @@ gpu.module @test_distribution {
// CHECK-DAG: %[[ALLOCA:.*]] = memref.alloca() : memref<4096xi8, 3>
// CHECK-DAG: %[[MEM_DESC:.*]] = xegpu.create_mem_desc %[[ALLOCA]] : memref<4096xi8, 3> -> !xegpu.mem_desc<1x32x32xf32>
// CHECK-DAG: %[[SGID:.*]] = gpu.subgroup_id : index
- // CHECK-DAG: %[[AFF0:.*]] = affine.apply #map()[%[[SGID]]]
- // CHECK-DAG: %[[AFF1:.*]] = affine.apply #map1()[%[[SGID]]]
- // CHECK-DAG: %[[AFF2:.*]] = affine.apply #map2()[%[[SGID]]]
+ // CHECK-DAG: %[[AFF0:.*]] = affine.apply #map2()[%[[SGID]]]
+ // CHECK-DAG: %[[AFF1:.*]] = affine.apply #map3()[%[[SGID]]]
+ // CHECK-DAG: %[[AFF2:.*]] = affine.apply #map4()[%[[SGID]]]
// CHECK-DAG: %[[ROW:.*]] = arith.muli %[[AFF0]], %[[C1A:.*]] : index
// CHECK-DAG: %[[COL0:.*]] = arith.muli %[[AFF1:.*]], %[[C1B:.*]] : index
// CHECK-DAG: %[[COL1:.*]] = arith.muli %[[AFF2]], %[[C32A:.*]] : index
More information about the Mlir-commits
mailing list