[Mlir-commits] [mlir] [MLIR][XeGPU] Lowering 2-Dimensional Reductions of N-D Tensors into Chained 1-D Reductions (PR #186034)

Jianhui Li llvmlistbot at llvm.org
Wed Mar 18 21:09:00 PDT 2026


https://github.com/Jianhui-Li updated https://github.com/llvm/llvm-project/pull/186034

>From d6db28591c8514a03fc85accbf5f81829366b664 Mon Sep 17 00:00:00 2001
From: Jianhui Li <jian.hui.li at intel.com>
Date: Wed, 11 Mar 2026 05:33:32 +0000
Subject: [PATCH 1/4] passing 2d reduction test with layouts

---
 .../XeGPU/Transforms/XeGPULayoutImpl.cpp      |  6 ++-
 .../Transforms/XeGPUPeepHoleOptimizer.cpp     | 46 ++++++++-----------
 .../XeGPU/Transforms/XeGPUPropagateLayout.cpp |  2 -
 .../test/Dialect/XeGPU/peephole-optimize.mlir | 20 +++++---
 4 files changed, 35 insertions(+), 39 deletions(-)

diff --git a/mlir/lib/Dialect/XeGPU/Transforms/XeGPULayoutImpl.cpp b/mlir/lib/Dialect/XeGPU/Transforms/XeGPULayoutImpl.cpp
index feefeb727a732..bbe2581aea99e 100644
--- a/mlir/lib/Dialect/XeGPU/Transforms/XeGPULayoutImpl.cpp
+++ b/mlir/lib/Dialect/XeGPU/Transforms/XeGPULayoutImpl.cpp
@@ -489,13 +489,15 @@ xegpu::SliceAttr xegpu::setupMultiReductionResultLayout(
     SmallVector<int64_t> instData(srcRank, 1);
     instData[srcRank - 2] =
         std::min(maxReduceVectorSize, srcShape[srcRank - 2]);
-    instData[srcRank - 1] = subgroupSize;
+    instData[srcRank - 1] =
+        std::min(static_cast<int64_t>(subgroupSize), srcShape[srcRank - 1]);
     srcLayout = xegpu::LayoutAttr::get(context, toInt32Attr(instData));
 
   } else if (layoutKind == xegpu::LayoutKind::Lane) {
 
     SmallVector<int64_t> laneLayout(srcRank, 1), laneData(srcRank, 1);
-    laneLayout[srcRank - 1] = subgroupSize;
+    laneLayout[srcRank - 1] =
+        std::min(static_cast<int64_t>(subgroupSize), srcShape[srcRank - 1]);
     laneData[srcRank - 2] =
         std::min(maxReduceVectorSize, srcShape[srcRank - 2]);
     srcLayout = xegpu::LayoutAttr::get(context, toInt32Attr(laneLayout),
diff --git a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUPeepHoleOptimizer.cpp b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUPeepHoleOptimizer.cpp
index d7a9b7ba377f9..9b55ae1aa319b 100644
--- a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUPeepHoleOptimizer.cpp
+++ b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUPeepHoleOptimizer.cpp
@@ -428,10 +428,6 @@ class MultiRed2dOpPattern
   matchAndRewrite(vector::MultiDimReductionOp reductionOp, OpAdaptor adaptor,
                   ConversionPatternRewriter &rewriter) const override {
     auto sourceVecType = reductionOp.getSourceVectorType();
-    if (reductionOp.getReductionDims().size() != 2 ||
-        sourceVecType.getRank() != 2)
-      return rewriter.notifyMatchFailure(
-          reductionOp, "Expected 2D multi reduction of a 2D source");
     auto resLayout = xegpu::getDistributeLayoutAttr(reductionOp.getResult());
     // Retrieve and order dims for 1D decomposition (prefer intra-lane first).
     auto dims = llvm::to_vector(reductionOp.getReductionDims());
@@ -444,33 +440,27 @@ class MultiRed2dOpPattern
     auto loc = reductionOp.getLoc();
     auto acc = reductionOp.getAcc();
 
-    // The first reduction's dist attribute does not have the cross lane dim.
-    auto resSliceLayoutAttr = cast<xegpu::SliceAttr>(resLayout);
-    SmallVector<int64_t> dropDims{crossLaneDim};
-    auto intraLaneRedResLayout = resSliceLayoutAttr.dropSliceDims(dropDims);
-
     SmallVector<int64_t> accShape(sourceVecType.getShape());
     accShape.erase(accShape.begin() + intraLaneDim);
-    if (acc) {
-      acc = vector::BroadcastOp::create(
-          rewriter, loc,
-          VectorType::get(accShape, sourceVecType.getElementType()), acc);
-      xegpu::setDistributeLayoutAttr(
-          llvm::dyn_cast<OpResult>(acc),
-          cast<xegpu::DistributeLayoutAttr>(intraLaneRedResLayout));
-    }
+    Type eTy = sourceVecType.getElementType();
+    Attribute eVal;
+    if (eTy.isFloat())
+      eVal = FloatAttr::get(eTy, 0.0);
+    else
+      eVal = IntegerAttr::get(eTy, 0);
+    Value const_zero = arith::ConstantOp::create(
+        rewriter, loc,
+        DenseElementsAttr::get(VectorType::get(accShape, eTy), eVal));
     Value intraLaneReduced = vector::MultiDimReductionOp::create(
-        rewriter, loc, reductionOp.getKind(), reductionOp.getSource(), acc,
-        ArrayRef<int64_t>(intraLaneDim));
-    xegpu::setDistributeLayoutAttr(
-        llvm::dyn_cast<OpResult>(intraLaneReduced),
-        cast<xegpu::DistributeLayoutAttr>(intraLaneRedResLayout));
-
-    Value crossLaneReduced = vector::ReductionOp::create(
-        rewriter, loc, reductionOp.getKind(), intraLaneReduced, nullptr);
-    xegpu::setDistributeLayoutAttr(
-        llvm::dyn_cast<OpResult>(crossLaneReduced),
-        cast<xegpu::DistributeLayoutAttr>(resLayout));
+        rewriter, loc, reductionOp.getKind(), reductionOp.getSource(),
+        const_zero, ArrayRef<int64_t>(intraLaneDim));
+
+    // Adjust crossLaneDim after the first reduction.
+    if (crossLaneDim > intraLaneDim)
+      crossLaneDim -= 1;
+    Value crossLaneReduced = vector::MultiDimReductionOp::create(
+        rewriter, loc, reductionOp.getKind(), intraLaneReduced, acc,
+        ArrayRef<int64_t>(crossLaneDim));
     assert(crossLaneReduced.getType() == reductionOp.getResult().getType() &&
            "Type mismatch");
     rewriter.replaceOp(reductionOp, crossLaneReduced);
diff --git a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUPropagateLayout.cpp b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUPropagateLayout.cpp
index b3fb00f35b167..62ba6323134a9 100644
--- a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUPropagateLayout.cpp
+++ b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUPropagateLayout.cpp
@@ -1135,7 +1135,6 @@ void LayoutInfoPropagation::visitLoadMatrixOp(
   if (!hasParamsOfLayoutKind(anchorLayout)) {
     VectorType resVecTy =
         llvm::cast<VectorType>(loadMatrixOp.getRes().getType());
-    assert(resVecTy.getRank() == 2 && "Expecting 2D vector for store matrix.");
     const uArch *uArch = getUArch(getChipStr(loadMatrixOp).value_or(""));
     if (!uArch)
       return;
@@ -1156,7 +1155,6 @@ void LayoutInfoPropagation::visitStoreMatrixOp(
   } else {
     VectorType srcVecTy =
         llvm::cast<VectorType>(storeMatrix.getData().getType());
-    assert(srcVecTy.getRank() == 2 && "Expecting 2D vector for store matrix.");
     const uArch *uArch = getUArch(getChipStr(storeMatrix).value_or(""));
     if (!uArch)
       return;
diff --git a/mlir/test/Dialect/XeGPU/peephole-optimize.mlir b/mlir/test/Dialect/XeGPU/peephole-optimize.mlir
index 83fec045b9973..04f4bf9430801 100644
--- a/mlir/test/Dialect/XeGPU/peephole-optimize.mlir
+++ b/mlir/test/Dialect/XeGPU/peephole-optimize.mlir
@@ -293,14 +293,20 @@ gpu.func @array_length(%arg0: vector<8x16xf16>, %arg1: memref<256x256xf16>, %arg
 
 // -----
 // CHECK-LABEL: gpu.func @vector_reduce_2d(
-// CHECK-SAME: %[[ARG0:[0-9a-zA-Z]+]]: memref<4x16xf32>, %[[ARG2:[0-9a-zA-Z]+]]: memref<256xf32>) {
-// CHECK:      %[[ACC_VEC:.*]] = arith.constant dense<1.000000e+00> : vector<16xf32>
+// CHECK-SAME: %[[ARG0:[0-9a-zA-Z]+]]: memref<4x16xf32>, %[[ARG1:[0-9a-zA-Z]+]]: memref<256xf32>) {
+// CHECK:      %[[MASK:.*]] = arith.constant {layout_result_0 = #xegpu.layout<lane_layout = [16], lane_data = [1]>} dense<true> : vector<16xi1>
+// CHECK:      %[[OFFSET:.*]] = arith.constant {layout_result_0 = #xegpu.layout<lane_layout = [16], lane_data = [1]>} dense<0> : vector<16xindex>
+// CHECK:      %[[ACC_VEC:.*]] = arith.constant dense<0.000000e+00> : vector<16xf32>
+// CHECK:      %[[ACC_SCALAR:.*]] = arith.constant {layout_result_0 = #xegpu.slice<#xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>, dims = [0, 1]>} 1.000000e+00 : f32
 // CHECK:      %[[TDESC:.*]] = xegpu.create_nd_tdesc %[[ARG0]] : memref<4x16xf32> -> !xegpu.tensor_desc<4x16xf32, #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>>
-// CHECK:      %[[LOADED:.*]] = xegpu.load_nd %[[TDESC]][0, 0]  : !xegpu.tensor_desc<4x16xf32, #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>> -> vector<4x16xf32>
-// CHECK:      %[[LOADED_REDUCED:.*]] = vector.multi_reduction <add>, %[[LOADED]], %[[ACC_VEC]]
-// CHECK-SAME: {layout_result_0 = #xegpu.slice<#xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>, dims = [0]>} [0] : vector<4x16xf32> to vector<16xf32>
-// CHECK:      %[[LOADED_REDUCED_FOR_CROSS:.*]] = vector.reduction <add>, %[[LOADED_REDUCED]]
-// CHECK-SAME: {layout_result_0 = #xegpu.slice<#xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>, dims = [0, 1]>} : vector<16xf32> into f32
+// CHECK:      %[[LOADED:.*]] = xegpu.load_nd %[[TDESC]][0, 0] : !xegpu.tensor_desc<4x16xf32, #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>> -> vector<4x16xf32>
+// CHECK:      %[[REDUCE_1:.*]] = vector.multi_reduction <add>, %[[LOADED]], %[[ACC_VEC]] [0] : vector<4x16xf32> to vector<16xf32>
+// CHECK:      %[[REDUCE_2:.*]] = vector.multi_reduction <add>, %[[REDUCE_1]], %[[ACC_SCALAR]] [0] : vector<16xf32> to f32
+// CHECK:      %[[BCAST:.*]] = vector.broadcast %[[REDUCE_2]]
+// CHECK-SAME: {layout_result_0 = #xegpu.slice<#xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>, dims = [0]>} : f32 to vector<16xf32>
+// CHECK:      xegpu.store %[[BCAST]], %[[ARG1]][%[[OFFSET]]], %[[MASK]]
+// CHECK-SAME: <{layout = #xegpu.slice<#xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>, dims = [0]>}>
+// CHECK-SAME: : vector<16xf32>, memref<256xf32>, vector<16xindex>, vector<16xi1>
 gpu.module @xevm_test {
   gpu.func @vector_reduce_2d(%src: memref<4x16xf32>, %dst: memref<256xf32>) {
     %cst = arith.constant {layout_result_0 = #xegpu.slice<#xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>, dims = [0, 1]>} 1.0 : f32

>From 16ffe18058529689baa8adefd4df564138f4678e Mon Sep 17 00:00:00 2001
From: Jianhui Li <jian.hui.li at intel.com>
Date: Thu, 12 Mar 2026 03:51:43 +0000
Subject: [PATCH 2/4] clean up

---
 mlir/lib/Dialect/XeGPU/Transforms/XeGPULayoutImpl.cpp      | 6 ++----
 mlir/lib/Dialect/XeGPU/Transforms/XeGPUPropagateLayout.cpp | 2 ++
 2 files changed, 4 insertions(+), 4 deletions(-)

diff --git a/mlir/lib/Dialect/XeGPU/Transforms/XeGPULayoutImpl.cpp b/mlir/lib/Dialect/XeGPU/Transforms/XeGPULayoutImpl.cpp
index bbe2581aea99e..feefeb727a732 100644
--- a/mlir/lib/Dialect/XeGPU/Transforms/XeGPULayoutImpl.cpp
+++ b/mlir/lib/Dialect/XeGPU/Transforms/XeGPULayoutImpl.cpp
@@ -489,15 +489,13 @@ xegpu::SliceAttr xegpu::setupMultiReductionResultLayout(
     SmallVector<int64_t> instData(srcRank, 1);
     instData[srcRank - 2] =
         std::min(maxReduceVectorSize, srcShape[srcRank - 2]);
-    instData[srcRank - 1] =
-        std::min(static_cast<int64_t>(subgroupSize), srcShape[srcRank - 1]);
+    instData[srcRank - 1] = subgroupSize;
     srcLayout = xegpu::LayoutAttr::get(context, toInt32Attr(instData));
 
   } else if (layoutKind == xegpu::LayoutKind::Lane) {
 
     SmallVector<int64_t> laneLayout(srcRank, 1), laneData(srcRank, 1);
-    laneLayout[srcRank - 1] =
-        std::min(static_cast<int64_t>(subgroupSize), srcShape[srcRank - 1]);
+    laneLayout[srcRank - 1] = subgroupSize;
     laneData[srcRank - 2] =
         std::min(maxReduceVectorSize, srcShape[srcRank - 2]);
     srcLayout = xegpu::LayoutAttr::get(context, toInt32Attr(laneLayout),
diff --git a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUPropagateLayout.cpp b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUPropagateLayout.cpp
index 62ba6323134a9..b3fb00f35b167 100644
--- a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUPropagateLayout.cpp
+++ b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUPropagateLayout.cpp
@@ -1135,6 +1135,7 @@ void LayoutInfoPropagation::visitLoadMatrixOp(
   if (!hasParamsOfLayoutKind(anchorLayout)) {
     VectorType resVecTy =
         llvm::cast<VectorType>(loadMatrixOp.getRes().getType());
+    assert(resVecTy.getRank() == 2 && "Expecting 2D vector for store matrix.");
     const uArch *uArch = getUArch(getChipStr(loadMatrixOp).value_or(""));
     if (!uArch)
       return;
@@ -1155,6 +1156,7 @@ void LayoutInfoPropagation::visitStoreMatrixOp(
   } else {
     VectorType srcVecTy =
         llvm::cast<VectorType>(storeMatrix.getData().getType());
+    assert(srcVecTy.getRank() == 2 && "Expecting 2D vector for store matrix.");
     const uArch *uArch = getUArch(getChipStr(storeMatrix).value_or(""));
     if (!uArch)
       return;

>From b093272e0483e96dcef6ecccf80d73cf0d78b150 Mon Sep 17 00:00:00 2001
From: Jianhui Li <jian.hui.li at intel.com>
Date: Thu, 12 Mar 2026 04:16:57 +0000
Subject: [PATCH 3/4] adding test

---
 .../test/Dialect/XeGPU/peephole-optimize.mlir | 41 +++++++++++++++++++
 1 file changed, 41 insertions(+)

diff --git a/mlir/test/Dialect/XeGPU/peephole-optimize.mlir b/mlir/test/Dialect/XeGPU/peephole-optimize.mlir
index 04f4bf9430801..06008ccafbccd 100644
--- a/mlir/test/Dialect/XeGPU/peephole-optimize.mlir
+++ b/mlir/test/Dialect/XeGPU/peephole-optimize.mlir
@@ -329,3 +329,44 @@ gpu.module @xevm_test {
     gpu.return
   }
 }
+
+// -----
+// CHECK-LABEL: gpu.func @vector_reduce_2d_with_leading_unit_dims(
+// CHECK-SAME: %[[ARG0:[0-9a-zA-Z]+]]: memref<4x16xf32>, %[[ARG1:[0-9a-zA-Z]+]]: memref<256xf32>) {
+// CHECK:      %[[MASK:.*]] = arith.constant {layout_result_0 = #xegpu.layout<lane_layout = [16], lane_data = [1]>} dense<true> : vector<16xi1>
+// CHECK:      %[[OFFSET:.*]] = arith.constant {layout_result_0 = #xegpu.layout<lane_layout = [16], lane_data = [1]>} dense<0> : vector<16xindex>
+// CHECK:      %[[ACC_2D:.*]] = arith.constant dense<0.000000e+00> : vector<1x16xf32>
+// CHECK:      %[[ACC_1D:.*]] = arith.constant {layout_result_0 = #xegpu.slice<#xegpu.layout<lane_layout = [1, 1, 16], lane_data = [1, 1, 1]>, dims = [1, 2]>} dense<1.000000e+00> : vector<1xf32>
+// CHECK:      %[[TDESC:.*]] = xegpu.create_nd_tdesc %[[ARG0]] : memref<4x16xf32> -> !xegpu.tensor_desc<4x16xf32, #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>>
+// CHECK:      %[[LOADED:.*]] = xegpu.load_nd %[[TDESC]][0, 0] : !xegpu.tensor_desc<4x16xf32, #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>> -> vector<4x16xf32>
+// CHECK:      %[[SHAPED:.*]] = vector.shape_cast %[[LOADED]] : vector<4x16xf32> to vector<1x4x16xf32>
+// CHECK:      %[[REDUCE_1:.*]] = vector.multi_reduction <add>, %[[SHAPED]], %[[ACC_2D]] [1] : vector<1x4x16xf32> to vector<1x16xf32>
+// CHECK:      %[[REDUCE_2:.*]] = vector.multi_reduction <add>, %[[REDUCE_1]], %[[ACC_1D]] [1] : vector<1x16xf32> to vector<1xf32>
+// CHECK:      %[[BCAST:.*]] = vector.broadcast %[[REDUCE_2]]
+// CHECK-SAME: {layout_result_0 = #xegpu.layout<lane_layout = [16], lane_data = [1]>} : vector<1xf32> to vector<16xf32>
+// CHECK:      xegpu.store %[[BCAST]], %[[ARG1]][%[[OFFSET]]], %[[MASK]]
+// CHECK-SAME: <{layout = #xegpu.layout<lane_layout = [16], lane_data = [1]>}>
+// CHECK-SAME: : vector<16xf32>, memref<256xf32>, vector<16xindex>, vector<16xi1>
+gpu.module @xevm_test {
+  gpu.func @vector_reduce_2d_with_leading_unit_dims(%src: memref<4x16xf32>, %dst: memref<256xf32>) {
+    %cst = arith.constant {layout_result_0 = #xegpu.slice<#xegpu.layout<lane_layout = [1, 1, 16], lane_data = [1, 1, 1]>, dims = [1, 2]>} dense<1.000000e+00> : vector<1xf32>
+    %tdesc = xegpu.create_nd_tdesc %src : memref<4x16xf32>
+      -> !xegpu.tensor_desc<4x16xf32, #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>>
+    %load =  xegpu.load_nd %tdesc[0, 0]
+      : !xegpu.tensor_desc<4x16xf32, #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>>
+      -> vector<4x16xf32>
+    %load1 = vector.broadcast %load {layout_result_0 = #xegpu.layout<lane_layout = [1, 1, 16], lane_data = [1, 1, 1]>}: vector<4x16xf32> to vector<1x4x16xf32>
+    %reduce = vector.multi_reduction <add>, %load1, %cst
+     {layout_result_0 = #xegpu.slice<#xegpu.layout<lane_layout = [1, 1, 16], lane_data = [1, 1, 1]>, dims = [1, 2]>}
+     [1, 2] : vector<1x4x16xf32> to vector<1xf32>
+    %reduce_bcast = vector.broadcast %reduce
+     {layout_result_0 = #xegpu.layout<lane_layout = [16], lane_data = [1]>}
+     : vector<1xf32> to vector<16xf32>
+
+    %offset = arith.constant {layout_result_0 = #xegpu.layout<lane_layout = [16], lane_data = [1]>} dense<0> : vector<16xindex>
+    %mask = arith.constant {layout_result_0 = #xegpu.layout<lane_layout = [16], lane_data = [1]>} dense<1> : vector<16xi1>
+
+    xegpu.store %reduce_bcast, %dst[%offset], %mask {layout = #xegpu.layout<lane_layout = [16], lane_data = [1]>} : vector<16xf32>, memref<256xf32>, vector<16xindex>, vector<16xi1>
+    gpu.return
+  }
+}

>From 60b914de138da87d62c002b32ac056a377f1e852 Mon Sep 17 00:00:00 2001
From: Jianhui Li <jian.hui.li at intel.com>
Date: Thu, 19 Mar 2026 04:08:37 +0000
Subject: [PATCH 4/4] address feedback

---
 .../mlir/Dialect/XeGPU/Utils/XeGPUUtils.h     |  8 ++
 .../Transforms/XeGPUPeepHoleOptimizer.cpp     | 15 ++--
 .../Transforms/XeGPUWgToSgDistribute.cpp      | 88 +------------------
 mlir/lib/Dialect/XeGPU/Utils/XeGPUUtils.cpp   | 77 ++++++++++++++++
 4 files changed, 95 insertions(+), 93 deletions(-)

diff --git a/mlir/include/mlir/Dialect/XeGPU/Utils/XeGPUUtils.h b/mlir/include/mlir/Dialect/XeGPU/Utils/XeGPUUtils.h
index e7cae506d9f4e..5a806799e896f 100644
--- a/mlir/include/mlir/Dialect/XeGPU/Utils/XeGPUUtils.h
+++ b/mlir/include/mlir/Dialect/XeGPU/Utils/XeGPUUtils.h
@@ -147,6 +147,14 @@ Value lowerToVectorReductions(TypedValue<VectorType> src,
                               vector::CombiningKind kind, int64_t reductionDim,
                               Location loc, PatternRewriter &rewriter);
 
+/// Creates a constant vector filled with the neutral (identity) value for the
+/// given reduction kind: 0 for ADD/OR/XOR/MAXUI, 1 for MUL/AND, the signed or
+/// unsigned maximum for MINSI/MINUI, the signed minimum for MAXSI, +infinity
+/// for float min and -infinity for float max operations. Returns nullptr if
+/// the element type is incompatible with the requested reduction kind.
+Value createReductionNeutralValue(OpBuilder &builder, Location loc,
+                                  VectorType type, vector::CombiningKind kind);
+
 /// Lowers cross-lane reductions to shuffle operations on a 2D vector.
 /// Extracts slices along the reduction dimension, performs subgroup reductions
 /// with shuffles across reductionSize work-items, and inserts the results back
diff --git a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUPeepHoleOptimizer.cpp b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUPeepHoleOptimizer.cpp
index 9b55ae1aa319b..0ece695aed512 100644
--- a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUPeepHoleOptimizer.cpp
+++ b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUPeepHoleOptimizer.cpp
@@ -428,6 +428,8 @@ class MultiRed2dOpPattern
   matchAndRewrite(vector::MultiDimReductionOp reductionOp, OpAdaptor adaptor,
                   ConversionPatternRewriter &rewriter) const override {
     auto sourceVecType = reductionOp.getSourceVectorType();
+    if (reductionOp.getReductionDims().size() != 2)
+      return rewriter.notifyMatchFailure(reductionOp, "Expected 2D reduction");
     auto resLayout = xegpu::getDistributeLayoutAttr(reductionOp.getResult());
     // Retrieve and order dims for 1D decomposition (prefer intra-lane first).
     auto dims = llvm::to_vector(reductionOp.getReductionDims());
@@ -443,17 +445,12 @@ class MultiRed2dOpPattern
     SmallVector<int64_t> accShape(sourceVecType.getShape());
     accShape.erase(accShape.begin() + intraLaneDim);
     Type eTy = sourceVecType.getElementType();
-    Attribute eVal;
-    if (eTy.isFloat())
-      eVal = FloatAttr::get(eTy, 0.0);
-    else
-      eVal = IntegerAttr::get(eTy, 0);
-    Value const_zero = arith::ConstantOp::create(
-        rewriter, loc,
-        DenseElementsAttr::get(VectorType::get(accShape, eTy), eVal));
+    Value constNeutralVal = xegpu::createReductionNeutralValue(
+        rewriter, loc, VectorType::get(accShape, eTy), reductionOp.getKind());
+
     Value intraLaneReduced = vector::MultiDimReductionOp::create(
         rewriter, loc, reductionOp.getKind(), reductionOp.getSource(),
-        const_zero, ArrayRef<int64_t>(intraLaneDim));
+        constNeutralVal, ArrayRef<int64_t>(intraLaneDim));
 
     // Adjust crossLaneDim after the first reduction.
     if (crossLaneDim > intraLaneDim)
diff --git a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUWgToSgDistribute.cpp b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUWgToSgDistribute.cpp
index 9f8dbc15f6422..b404ecf189842 100644
--- a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUWgToSgDistribute.cpp
+++ b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUWgToSgDistribute.cpp
@@ -1198,86 +1198,6 @@ struct WgToSgVectorShapeCastOp
   }
 };
 
-static Value createAccumulator(ConversionPatternRewriter &rewriter,
-                               Location loc, VectorType type,
-                               vector::CombiningKind kind) {
-  Type elemTy = type.getElementType();
-
-  switch (kind) {
-  case vector::CombiningKind::ADD:
-  case vector::CombiningKind::XOR:
-  case vector::CombiningKind::OR:
-    return arith::ConstantOp::create(
-        rewriter, loc, type,
-        DenseElementsAttr::get(type, rewriter.getZeroAttr(elemTy)));
-
-  case vector::CombiningKind::MUL:
-  case vector::CombiningKind::AND:
-    return arith::ConstantOp::create(
-        rewriter, loc, type,
-        DenseElementsAttr::get(type, rewriter.getOneAttr(elemTy)));
-
-  case vector::CombiningKind::MINSI:
-    // Use max signed int value for signed integer min
-    if (auto intTy = dyn_cast<IntegerType>(elemTy)) {
-      auto maxVal = APInt::getSignedMaxValue(intTy.getWidth());
-      return arith::ConstantOp::create(
-          rewriter, loc, type,
-          DenseElementsAttr::get(type,
-                                 rewriter.getIntegerAttr(elemTy, maxVal)));
-    }
-    return nullptr;
-
-  case vector::CombiningKind::MINUI:
-    if (auto intTy = dyn_cast<IntegerType>(elemTy)) {
-      auto maxVal = APInt::getMaxValue(intTy.getWidth());
-      return arith::ConstantOp::create(
-          rewriter, loc, type,
-          DenseElementsAttr::get(type,
-                                 rewriter.getIntegerAttr(elemTy, maxVal)));
-    }
-    return nullptr;
-
-  case vector::CombiningKind::MAXSI:
-    if (auto intTy = dyn_cast<IntegerType>(elemTy)) {
-      auto minVal = APInt::getSignedMinValue(intTy.getWidth());
-      return arith::ConstantOp::create(
-          rewriter, loc, type,
-          DenseElementsAttr::get(type,
-                                 rewriter.getIntegerAttr(elemTy, minVal)));
-    }
-    return nullptr;
-
-  case vector::CombiningKind::MAXUI:
-    return arith::ConstantOp::create(
-        rewriter, loc, type,
-        DenseElementsAttr::get(type, rewriter.getZeroAttr(elemTy)));
-
-  case vector::CombiningKind::MINNUMF:
-  case vector::CombiningKind::MINIMUMF:
-    // Use +infinity for float min operations
-    if (auto floatTy = dyn_cast<FloatType>(elemTy)) {
-      auto posInf = APFloat::getInf(floatTy.getFloatSemantics());
-      return arith::ConstantOp::create(
-          rewriter, loc, type,
-          DenseElementsAttr::get(type, rewriter.getFloatAttr(elemTy, posInf)));
-    }
-    return nullptr;
-
-  case vector::CombiningKind::MAXNUMF:
-  case vector::CombiningKind::MAXIMUMF:
-    // Use -infinity for float max operations
-    if (auto floatTy = dyn_cast<FloatType>(elemTy)) {
-      auto negInf = APFloat::getInf(floatTy.getFloatSemantics(), true);
-      return arith::ConstantOp::create(
-          rewriter, loc, type,
-          DenseElementsAttr::get(type, rewriter.getFloatAttr(elemTy, negInf)));
-    }
-    return nullptr;
-  }
-  return nullptr;
-}
-
 /// This pattern transforms vector.multi_dim_reduction operations from
 /// workgroup-level to subgroup-level execution with support for multiple
 /// reduction dimensions.
@@ -1359,8 +1279,8 @@ struct WgToSgMultiDimReductionOp
     VectorType newDstType = VectorType::get(sgDstShape, elemTy);
     for (auto sgSrc : sgSrcs) {
       // Create ZERO accumulator for local reduction
-      auto neutralLocalAcc =
-          createAccumulator(rewriter, loc, newDstType, op.getKind());
+      auto neutralLocalAcc = xegpu::createReductionNeutralValue(
+          rewriter, loc, newDstType, op.getKind());
       // Local reduction with ZERO accumulator
       auto localReduce = vector::MultiDimReductionOp::create(
           rewriter, loc, newDstType, op.getKind(), sgSrc, neutralLocalAcc,
@@ -1481,8 +1401,8 @@ struct WgToSgMultiDimReductionOp
         /*layout=*/nullptr);
 
     // Step 6: Perform final reduction with ZERO accumulator
-    auto neutralFinalAcc =
-        createAccumulator(rewriter, loc, newDstType, op.getKind());
+    auto neutralFinalAcc = xegpu::createReductionNeutralValue(
+        rewriter, loc, newDstType, op.getKind());
 
     auto finalReduce = vector::MultiDimReductionOp::create(
         rewriter, loc, newDstType, op.getKind(), slmLoadOp.getResult(),
diff --git a/mlir/lib/Dialect/XeGPU/Utils/XeGPUUtils.cpp b/mlir/lib/Dialect/XeGPU/Utils/XeGPUUtils.cpp
index a57bf8512ddec..c30cd6edefc55 100644
--- a/mlir/lib/Dialect/XeGPU/Utils/XeGPUUtils.cpp
+++ b/mlir/lib/Dialect/XeGPU/Utils/XeGPUUtils.cpp
@@ -797,6 +797,83 @@ Value xegpu::lowerCrossLaneReductionToShuffles(
   return reductionResult;
 }
 
+Value xegpu::createReductionNeutralValue(OpBuilder &builder, Location loc,
+                                         VectorType type,
+                                         vector::CombiningKind kind) {
+  Type elemTy = type.getElementType();
+  // Build a constant of `type` filled with the identity element of `kind`.
+  switch (kind) {
+  case vector::CombiningKind::ADD: // identity: 0 (shared by ADD/XOR/OR)
+  case vector::CombiningKind::XOR:
+  case vector::CombiningKind::OR:
+    return arith::ConstantOp::create(
+        builder, loc, type,
+        DenseElementsAttr::get(type, builder.getZeroAttr(elemTy)));
+  // NOTE(review): AND's identity on multi-bit ints is all-ones, not 1; confirm.
+  case vector::CombiningKind::MUL:
+  case vector::CombiningKind::AND:
+    return arith::ConstantOp::create(
+        builder, loc, type,
+        DenseElementsAttr::get(type, builder.getOneAttr(elemTy)));
+
+  case vector::CombiningKind::MINSI:
+    // Signed max is the identity for signed min; non-integers get nullptr.
+    if (auto intTy = dyn_cast<IntegerType>(elemTy)) {
+      auto maxVal = APInt::getSignedMaxValue(intTy.getWidth());
+      return arith::ConstantOp::create(
+          builder, loc, type,
+          DenseElementsAttr::get(type, builder.getIntegerAttr(elemTy, maxVal)));
+    }
+    return nullptr;
+
+  case vector::CombiningKind::MINUI: // identity: unsigned max value
+    if (auto intTy = dyn_cast<IntegerType>(elemTy)) {
+      auto maxVal = APInt::getMaxValue(intTy.getWidth());
+      return arith::ConstantOp::create(
+          builder, loc, type,
+          DenseElementsAttr::get(type, builder.getIntegerAttr(elemTy, maxVal)));
+    }
+    return nullptr;
+
+  case vector::CombiningKind::MAXSI: // identity: signed min value
+    if (auto intTy = dyn_cast<IntegerType>(elemTy)) {
+      auto minVal = APInt::getSignedMinValue(intTy.getWidth());
+      return arith::ConstantOp::create(
+          builder, loc, type,
+          DenseElementsAttr::get(type, builder.getIntegerAttr(elemTy, minVal)));
+    }
+    return nullptr;
+
+  case vector::CombiningKind::MAXUI: // identity: 0; elemTy is not checked here
+    return arith::ConstantOp::create(
+        builder, loc, type,
+        DenseElementsAttr::get(type, builder.getZeroAttr(elemTy)));
+
+  case vector::CombiningKind::MINNUMF:
+  case vector::CombiningKind::MINIMUMF:
+    // +infinity is the identity for float min; non-float types get nullptr.
+    if (auto floatTy = dyn_cast<FloatType>(elemTy)) {
+      auto posInf = APFloat::getInf(floatTy.getFloatSemantics());
+      return arith::ConstantOp::create(
+          builder, loc, type,
+          DenseElementsAttr::get(type, builder.getFloatAttr(elemTy, posInf)));
+    }
+    return nullptr;
+
+  case vector::CombiningKind::MAXNUMF:
+  case vector::CombiningKind::MAXIMUMF:
+    // -infinity is the identity for float max; non-float types get nullptr.
+    if (auto floatTy = dyn_cast<FloatType>(elemTy)) {
+      auto negInf = APFloat::getInf(floatTy.getFloatSemantics(), true);
+      return arith::ConstantOp::create(
+          builder, loc, type,
+          DenseElementsAttr::get(type, builder.getFloatAttr(elemTy, negInf)));
+    }
+    return nullptr;
+  }
+  return nullptr;
+}
+
 /// Explicit instantiations
 template int xegpu::getLargestDivisor<int>(int dim, ArrayRef<int> candidates,
                                            ArrayRef<int> candidateMultiples);



More information about the Mlir-commits mailing list