[Mlir-commits] [mlir] [MLIR][XeGPU] Add support for lowering vector.multi_reduction to scalar in Wg to Sg (PR #188623)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Wed Mar 25 14:55:53 PDT 2026


llvmbot wrote:


<!--LLVM PR SUMMARY COMMENT-->
@llvm/pr-subscribers-mlir

@llvm/pr-subscribers-mlir-gpu

Author: Nishant Patel (nbpatel)

<details>
<summary>Changes</summary>



---
Full diff: https://github.com/llvm/llvm-project/pull/188623.diff


4 Files Affected:

- (modified) mlir/include/mlir/Dialect/XeGPU/Utils/XeGPUUtils.h (+6-5) 
- (modified) mlir/lib/Dialect/XeGPU/Transforms/XeGPUWgToSgDistribute.cpp (+36-21) 
- (modified) mlir/lib/Dialect/XeGPU/Utils/XeGPUUtils.cpp (+29-46) 
- (modified) mlir/test/Dialect/XeGPU/xegpu-wg-to-sg-unify-ops.mlir (+35-8) 


``````````diff
diff --git a/mlir/include/mlir/Dialect/XeGPU/Utils/XeGPUUtils.h b/mlir/include/mlir/Dialect/XeGPU/Utils/XeGPUUtils.h
index 5a806799e896f..0aa2cd45088f3 100644
--- a/mlir/include/mlir/Dialect/XeGPU/Utils/XeGPUUtils.h
+++ b/mlir/include/mlir/Dialect/XeGPU/Utils/XeGPUUtils.h
@@ -147,13 +147,14 @@ Value lowerToVectorReductions(TypedValue<VectorType> src,
                               vector::CombiningKind kind, int64_t reductionDim,
                               Location loc, PatternRewriter &rewriter);
 
-/// Creates a constant vector filled with the neutral (identity) value for the
+/// Creates a constant filled with the neutral (identity) value for the
 /// given reduction kind. For example: 0 for ADD/OR/XOR, 1 for MUL/AND,
 /// max/min signed/unsigned int for MINSI/MINUI/MAXSI/MAXUI, and +/-infinity
-/// for float min/max operations. Returns nullptr if the element type is
-/// incompatible with the requested reduction kind.
-Value createReductionNeutralValue(OpBuilder &builder, Location loc,
-                                  VectorType type, vector::CombiningKind kind);
+/// for float min/max operations. If \p type is a VectorType, returns a splat
+/// vector constant; otherwise returns a scalar constant. Returns nullptr if
+/// the element type is incompatible with the requested reduction kind.
+Value createReductionNeutralValue(OpBuilder &builder, Location loc, Type type,
+                                  vector::CombiningKind kind);
 
 /// Lowers cross-lane reductions to shuffle operations on a 2D vector.
 /// Extracts slices along the reduction dimension, performs subgroup reductions
diff --git a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUWgToSgDistribute.cpp b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUWgToSgDistribute.cpp
index 6dea94c0c5de3..3d1d1ca3ecf98 100644
--- a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUWgToSgDistribute.cpp
+++ b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUWgToSgDistribute.cpp
@@ -1233,13 +1233,13 @@ struct WgToSgMultiDimReductionOp
     Location loc = op.getLoc();
 
     VectorType srcType = op.getSourceVectorType();
-    VectorType dstType = dyn_cast<VectorType>(op.getResult().getType());
-    if (!dstType)
-      return failure();
+    Type resultTy = op.getResult().getType();
+    VectorType dstVecType = dyn_cast<VectorType>(resultTy);
+    bool isScalarResult = !dstVecType;
 
     auto originalSrcShape = srcType.getShape();
-    auto originalDstShape = dstType.getShape();
     int srcVecRank = originalSrcShape.size();
+    Type elemTy = srcType.getElementType();
 
     xegpu::DistributeLayoutAttr layout =
         xegpu::getTemporaryLayout(dyn_cast<OpResult>(op.getResult()));
@@ -1258,25 +1258,33 @@ struct WgToSgMultiDimReductionOp
       return rewriter.notifyMatchFailure(
           op, "Reduction should have SliceAttr layout");
 
-    Type elemTy = dstType.getElementType();
-
-    // Step 1: perform local subgroup reductions with ZERO accumulator
+    // Step 1: perform local subgroup reductions with neutral accumulator
     SmallVector<Value> localReductions;
-    SmallVector<int64_t> sgDstShape =
-        getSgShapeAndCount(originalDstShape, layout).first;
     auto sgSrcs = adaptor.getSource();
     auto sgSrcType = dyn_cast<VectorType>(sgSrcs.front().getType());
     SmallVector<int64_t> sgSrcShape(sgSrcType.getShape().begin(),
                                     sgSrcType.getShape().end());
 
-    VectorType newDstType = VectorType::get(sgDstShape, elemTy);
+    // Determine the SG-level destination type.
+    // For scalar results (all dims reduced), the sg result is also scalar.
+    // For vector results, compute the sg destination shape from layout.
+    Type sgDstType;
+    if (dstVecType) {
+      auto originalDstShape = dstVecType.getShape();
+      SmallVector<int64_t> sgDstShape =
+          getSgShapeAndCount(originalDstShape, layout).first;
+      sgDstType = VectorType::get(sgDstShape, elemTy);
+    } else {
+      sgDstType = elemTy;
+    }
+
     for (auto sgSrc : sgSrcs) {
-      // Create ZERO accumulator for local reduction
-      auto neutralLocalAcc = xegpu::createReductionNeutralValue(
-          rewriter, loc, newDstType, op.getKind());
-      // Local reduction with ZERO accumulator
+      // Create neutral accumulator for local reduction
+      Value neutralLocalAcc = xegpu::createReductionNeutralValue(
+          rewriter, loc, sgDstType, op.getKind());
+      // Local reduction with neutral accumulator
       auto localReduce = vector::MultiDimReductionOp::create(
-          rewriter, loc, newDstType, op.getKind(), sgSrc, neutralLocalAcc,
+          rewriter, loc, sgDstType, op.getKind(), sgSrc, neutralLocalAcc,
           reductionDims);
       localReductions.push_back(localReduce.getResult());
     }
@@ -1310,8 +1318,15 @@ struct WgToSgMultiDimReductionOp
     for (int64_t dim : reductionDims)
       slmStoreDataShape[dim] = 1;
     VectorType slmStoreDataType = VectorType::get(slmStoreDataShape, elemTy);
-    Value slmStoreData = vector::ShapeCastOp::create(
-        rewriter, loc, slmStoreDataType, localReductions[0]);
+    Value slmStoreData;
+    if (isScalarResult) {
+      // Scalar result: broadcast scalar to vector<1x...x1> for SLM store
+      slmStoreData = vector::BroadcastOp::create(
+          rewriter, loc, slmStoreDataType, localReductions[0]);
+    } else {
+      slmStoreData = vector::ShapeCastOp::create(
+          rewriter, loc, slmStoreDataType, localReductions[0]);
+    }
 
     SmallVector<int64_t> slmShape(originalSrcShape.begin(),
                                   originalSrcShape.end());
@@ -1393,12 +1408,12 @@ struct WgToSgMultiDimReductionOp
         rewriter, loc, slmLoadType, memDesc.getResult(), slmLoadOffsets,
         /*layout=*/nullptr);
 
-    // Step 6: Perform final reduction with ZERO accumulator
-    auto neutralFinalAcc = xegpu::createReductionNeutralValue(
-        rewriter, loc, newDstType, op.getKind());
+    // Step 6: Perform final reduction with neutral accumulator
+    Value neutralFinalAcc = xegpu::createReductionNeutralValue(
+        rewriter, loc, sgDstType, op.getKind());
 
     auto finalReduce = vector::MultiDimReductionOp::create(
-        rewriter, loc, newDstType, op.getKind(), slmLoadOp.getResult(),
+        rewriter, loc, sgDstType, op.getKind(), slmLoadOp.getResult(),
         neutralFinalAcc, reductionDims);
 
     // Step 7: Add the original accumulator at the end
diff --git a/mlir/lib/Dialect/XeGPU/Utils/XeGPUUtils.cpp b/mlir/lib/Dialect/XeGPU/Utils/XeGPUUtils.cpp
index f60635830cc74..6c902f725ca0c 100644
--- a/mlir/lib/Dialect/XeGPU/Utils/XeGPUUtils.cpp
+++ b/mlir/lib/Dialect/XeGPU/Utils/XeGPUUtils.cpp
@@ -801,77 +801,60 @@ Value xegpu::lowerCrossLaneReductionToShuffles(
 }
 
 Value xegpu::createReductionNeutralValue(OpBuilder &builder, Location loc,
-                                         VectorType type,
+                                         Type type,
                                          vector::CombiningKind kind) {
-  Type elemTy = type.getElementType();
+  auto vecTy = dyn_cast<VectorType>(type);
+  Type elemTy = vecTy ? vecTy.getElementType() : type;
+
+  // Helper to create either a splat vector or scalar constant from an attr.
+  auto makeConst = [&](Attribute scalarAttr) -> Value {
+    if (vecTy)
+      return arith::ConstantOp::create(
+          builder, loc, vecTy, DenseElementsAttr::get(vecTy, scalarAttr));
+    return arith::ConstantOp::create(builder, loc, cast<TypedAttr>(scalarAttr));
+  };
 
   switch (kind) {
   case vector::CombiningKind::ADD:
   case vector::CombiningKind::XOR:
   case vector::CombiningKind::OR:
-    return arith::ConstantOp::create(
-        builder, loc, type,
-        DenseElementsAttr::get(type, builder.getZeroAttr(elemTy)));
+  case vector::CombiningKind::MAXUI:
+    return makeConst(builder.getZeroAttr(elemTy));
 
   case vector::CombiningKind::MUL:
   case vector::CombiningKind::AND:
-    return arith::ConstantOp::create(
-        builder, loc, type,
-        DenseElementsAttr::get(type, builder.getOneAttr(elemTy)));
+    return makeConst(builder.getOneAttr(elemTy));
 
   case vector::CombiningKind::MINSI:
-    // Use max signed int value for signed integer min
-    if (auto intTy = dyn_cast<IntegerType>(elemTy)) {
-      auto maxVal = APInt::getSignedMaxValue(intTy.getWidth());
-      return arith::ConstantOp::create(
-          builder, loc, type,
-          DenseElementsAttr::get(type, builder.getIntegerAttr(elemTy, maxVal)));
-    }
+    if (auto intTy = dyn_cast<IntegerType>(elemTy))
+      return makeConst(builder.getIntegerAttr(
+          elemTy, APInt::getSignedMaxValue(intTy.getWidth())));
     return nullptr;
 
   case vector::CombiningKind::MINUI:
-    if (auto intTy = dyn_cast<IntegerType>(elemTy)) {
-      auto maxVal = APInt::getMaxValue(intTy.getWidth());
-      return arith::ConstantOp::create(
-          builder, loc, type,
-          DenseElementsAttr::get(type, builder.getIntegerAttr(elemTy, maxVal)));
-    }
+    if (auto intTy = dyn_cast<IntegerType>(elemTy))
+      return makeConst(
+          builder.getIntegerAttr(elemTy, APInt::getMaxValue(intTy.getWidth())));
     return nullptr;
 
   case vector::CombiningKind::MAXSI:
-    if (auto intTy = dyn_cast<IntegerType>(elemTy)) {
-      auto minVal = APInt::getSignedMinValue(intTy.getWidth());
-      return arith::ConstantOp::create(
-          builder, loc, type,
-          DenseElementsAttr::get(type, builder.getIntegerAttr(elemTy, minVal)));
-    }
+    if (auto intTy = dyn_cast<IntegerType>(elemTy))
+      return makeConst(builder.getIntegerAttr(
+          elemTy, APInt::getSignedMinValue(intTy.getWidth())));
     return nullptr;
 
-  case vector::CombiningKind::MAXUI:
-    return arith::ConstantOp::create(
-        builder, loc, type,
-        DenseElementsAttr::get(type, builder.getZeroAttr(elemTy)));
-
   case vector::CombiningKind::MINNUMF:
   case vector::CombiningKind::MINIMUMF:
-    // Use +infinity for float min operations
-    if (auto floatTy = dyn_cast<FloatType>(elemTy)) {
-      auto posInf = APFloat::getInf(floatTy.getFloatSemantics());
-      return arith::ConstantOp::create(
-          builder, loc, type,
-          DenseElementsAttr::get(type, builder.getFloatAttr(elemTy, posInf)));
-    }
+    if (auto floatTy = dyn_cast<FloatType>(elemTy))
+      return makeConst(builder.getFloatAttr(
+          elemTy, APFloat::getInf(floatTy.getFloatSemantics())));
     return nullptr;
 
   case vector::CombiningKind::MAXNUMF:
   case vector::CombiningKind::MAXIMUMF:
-    // Use -infinity for float max operations
-    if (auto floatTy = dyn_cast<FloatType>(elemTy)) {
-      auto negInf = APFloat::getInf(floatTy.getFloatSemantics(), true);
-      return arith::ConstantOp::create(
-          builder, loc, type,
-          DenseElementsAttr::get(type, builder.getFloatAttr(elemTy, negInf)));
-    }
+    if (auto floatTy = dyn_cast<FloatType>(elemTy))
+      return makeConst(builder.getFloatAttr(
+          elemTy, APFloat::getInf(floatTy.getFloatSemantics(), true)));
     return nullptr;
   }
   return nullptr;
diff --git a/mlir/test/Dialect/XeGPU/xegpu-wg-to-sg-unify-ops.mlir b/mlir/test/Dialect/XeGPU/xegpu-wg-to-sg-unify-ops.mlir
index ecc5fe3dd75e0..950d9ba66f0cc 100644
--- a/mlir/test/Dialect/XeGPU/xegpu-wg-to-sg-unify-ops.mlir
+++ b/mlir/test/Dialect/XeGPU/xegpu-wg-to-sg-unify-ops.mlir
@@ -1,10 +1,10 @@
 // RUN: mlir-opt --xegpu-wg-to-sg-distribute -split-input-file %s | FileCheck %s
 
-// CHECK-DAG: #map = affine_map<()[s0] -> (s0 floordiv 32)>
-// CHECK-DAG: #map1 = affine_map<()[s0] -> (s0 mod 32)>
-// CHECK-DAG: #map2 = affine_map<()[s0] -> (0)>
-// CHECK-DAG: #map3 = affine_map<()[s0] -> (s0 floordiv 4)>
-// CHECK-DAG: #map4 = affine_map<()[s0] -> (s0 mod 4)>
+// CHECK-DAG: #map = affine_map<()[s0] -> (s0 floordiv 4)>
+// CHECK-DAG: #map1 = affine_map<()[s0] -> (s0 mod 4)>
+// CHECK-DAG: #map2 = affine_map<()[s0] -> (s0 floordiv 32)>
+// CHECK-DAG: #map3 = affine_map<()[s0] -> (s0 mod 32)>
+// CHECK-DAG: #map4 = affine_map<()[s0] -> (0)>
 // CHECK-DAG: #map5 = affine_map<()[s0] -> ((s0 mod 32) floordiv 16)>
 // CHECK-DAG: #map6 = affine_map<()[s0] -> (s0 mod 16)>
 // CHECK-DAG: #map7 = affine_map<()[s0] -> ((s0 mod 16) floordiv 4)>
@@ -412,6 +412,33 @@ gpu.module @test_distribution {
       gpu.return
     }
 
+  // CHECK-LABEL: gpu.func @vector_reduce_scalar_cross_sg
+  // CHECK-SAME: (%[[ARG0:.*]]: memref<32x32xf32>)
+  // CHECK-DAG: %[[CST:.*]] = arith.constant {{.*}} 0.000000e+00 : f32
+  // CHECK-DAG: %[[LOAD:.*]] = xegpu.load_nd %{{.*}} : !xegpu.tensor_desc<8x8xf32> -> vector<8x8xf32>
+  // CHECK-DAG: %[[CST_ACC:.*]] = arith.constant 0.000000e+00 : f32
+  // CHECK-DAG: %[[LOCAL:.*]] = vector.multi_reduction <add>, %[[LOAD]], %[[CST_ACC]] [0, 1] : vector<8x8xf32> to f32
+  // CHECK-DAG: %[[BCAST:.*]] = vector.broadcast %[[LOCAL]] : f32 to vector<1x1xf32>
+  // CHECK-DAG: %[[ALLOCA:.*]] = memref.alloca() : memref<64xi8, 3>
+  // CHECK-DAG: %[[MEM_DESC:.*]] = xegpu.create_mem_desc %[[ALLOCA]] : memref<64xi8, 3> -> !xegpu.mem_desc<4x4xf32>
+  // CHECK-DAG: xegpu.store_matrix %[[BCAST]], %[[MEM_DESC]]{{.*}} : vector<1x1xf32>, !xegpu.mem_desc<4x4xf32>
+  // CHECK-DAG: gpu.barrier
+  // CHECK-DAG: %[[LOAD_SLM:.*]] = xegpu.load_matrix %[[MEM_DESC]]{{.*}} -> vector<4x4xf32>
+  // CHECK-DAG: %[[CST_FINAL:.*]] = arith.constant 0.000000e+00 : f32
+  // CHECK-DAG: %[[FINAL:.*]] = vector.multi_reduction <add>, %[[LOAD_SLM]], %[[CST_FINAL]] [0, 1] : vector<4x4xf32> to f32
+  // CHECK-DAG: arith.addf %[[FINAL]], %[[CST]] : f32
+  gpu.func @vector_reduce_scalar_cross_sg(%src: memref<32x32xf32>) {
+    %cst = arith.constant {layout_result_0 = #xegpu.slice<#xegpu.layout<sg_layout = [4, 4], sg_data = [8, 8]>, dims = [0, 1]>} 0.0 : f32
+    %tdesc = xegpu.create_nd_tdesc %src[0, 0] : memref<32x32xf32>
+      -> !xegpu.tensor_desc<32x32xf32, #xegpu.layout<sg_layout = [4, 4], sg_data = [8, 8]>>
+    %load = xegpu.load_nd %tdesc[0, 0] {layout = #xegpu.layout<sg_layout = [4, 4], sg_data = [8, 8]>}
+      : !xegpu.tensor_desc<32x32xf32, #xegpu.layout<sg_layout = [4, 4], sg_data = [8, 8]>>
+      -> vector<32x32xf32>
+    %reduce = vector.multi_reduction <add>, %load, %cst {layout_result_0 = #xegpu.slice<#xegpu.layout<sg_layout = [4, 4], sg_data = [8, 8]>, dims = [0, 1]>} [0, 1]
+      : vector<32x32xf32> to f32
+    gpu.return
+  }
+
   // CHECK-LABEL: vector_step_op
   gpu.func @vector_step_op_slice_attr() {
     //CHECK: [[sgId:%.+]] = gpu.subgroup_id : index
@@ -654,9 +681,9 @@ gpu.module @test_distribution {
     // CHECK-DAG: %[[ALLOCA:.*]] = memref.alloca() : memref<4096xi8, 3>
     // CHECK-DAG: %[[MEM_DESC:.*]] = xegpu.create_mem_desc %[[ALLOCA]] : memref<4096xi8, 3> -> !xegpu.mem_desc<1x32x32xf32>
     // CHECK-DAG: %[[SGID:.*]] = gpu.subgroup_id : index
-    // CHECK-DAG: %[[AFF0:.*]] = affine.apply #map()[%[[SGID]]]
-    // CHECK-DAG: %[[AFF1:.*]] = affine.apply #map1()[%[[SGID]]]
-    // CHECK-DAG: %[[AFF2:.*]] = affine.apply #map2()[%[[SGID]]]
+    // CHECK-DAG: %[[AFF0:.*]] = affine.apply #map2()[%[[SGID]]]
+    // CHECK-DAG: %[[AFF1:.*]] = affine.apply #map3()[%[[SGID]]]
+    // CHECK-DAG: %[[AFF2:.*]] = affine.apply #map4()[%[[SGID]]]
     // CHECK-DAG: %[[ROW:.*]] = arith.muli %[[AFF0]], %[[C1A:.*]] : index
     // CHECK-DAG: %[[COL0:.*]] = arith.muli %[[AFF1:.*]], %[[C1B:.*]] : index
     // CHECK-DAG: %[[COL1:.*]] = arith.muli %[[AFF2]], %[[C32A:.*]] : index

``````````

</details>


https://github.com/llvm/llvm-project/pull/188623


More information about the Mlir-commits mailing list