[Mlir-commits] [mlir] [MLIR][XeGPU] Support leading unit dims in vector.multi_reduction in sg to wi pass (PR #188767)

Nishant Patel llvmlistbot at llvm.org
Thu Mar 26 08:23:11 PDT 2026


https://github.com/nbpatel created https://github.com/llvm/llvm-project/pull/188767

None

>From 6af4262cab5a33784b0de4bdf813d5b33fc1e629 Mon Sep 17 00:00:00 2001
From: nbpatel <nishant.b.patel at intel.com>
Date: Thu, 26 Mar 2026 03:47:47 +0000
Subject: [PATCH] Support leading unit dims in reduction

---
 .../XeGPUSgToWiDistributeExperimental.cpp     | 11 +++
 mlir/lib/Dialect/XeGPU/Utils/XeGPUUtils.cpp   | 50 +++++++----
 .../XeGPU/sg-to-wi-experimental-unit.mlir     | 87 +++++++++++++++++++
 3 files changed, 130 insertions(+), 18 deletions(-)

diff --git a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUSgToWiDistributeExperimental.cpp b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUSgToWiDistributeExperimental.cpp
index 0961ddfb92040..0b19ddf6163a0 100644
--- a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUSgToWiDistributeExperimental.cpp
+++ b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUSgToWiDistributeExperimental.cpp
@@ -588,6 +588,17 @@ struct SgToWiMultiDimReduction
     assert(reductionDims.size() == 1 &&
            "Expecting single reduction dimension for subgroup multi "
            "reduction op");
+    // For rank > 2, ensure leading dimensions are unit.
+    VectorType sourceType = op.getSourceVectorType();
+    int64_t rank = sourceType.getRank();
+    if (rank > 2) {
+      ArrayRef<int64_t> shape = sourceType.getShape();
+      if (llvm::any_of(shape.take_front(rank - 2),
+                       [](int64_t d) { return d != 1; }))
+        return rewriter.notifyMatchFailure(
+            op, "only unit leading dimensions are supported for "
+                "multi_reduction with rank > 2");
+    }
     if (isReductionLaneLocal(op)) {
       auto resLayout = xegpu::getTemporaryLayout(op->getOpResult(0));
       VectorType resVecTy = dyn_cast<VectorType>(op.getType());
diff --git a/mlir/lib/Dialect/XeGPU/Utils/XeGPUUtils.cpp b/mlir/lib/Dialect/XeGPU/Utils/XeGPUUtils.cpp
index f60635830cc74..930c9e898dec4 100644
--- a/mlir/lib/Dialect/XeGPU/Utils/XeGPUUtils.cpp
+++ b/mlir/lib/Dialect/XeGPU/Utils/XeGPUUtils.cpp
@@ -750,11 +750,18 @@ Value xegpu::lowerCrossLaneReductionToShuffles(
     TypedValue<VectorType> src, TypedValue<VectorType> acc,
     vector::CombiningKind kind, int64_t reductionDim, int64_t reductionSize,
     Location loc, PatternRewriter &rewriter) {
-  // Expecting a 2D source vector.
-  assert(src.getType().getRank() == 2 && "expected a 2D source vector");
   VectorType sourceType = src.getType();
-  int64_t sourceH = sourceType.getShape()[0];
-  int64_t sourceW = sourceType.getShape()[1];
+  int64_t sourceRank = sourceType.getRank();
+  // Expecting at least a 2D source vector. Leading dimensions (all except the
+  // last two) must be unit.
+  assert(sourceRank >= 2 && "expected at least a 2D source vector");
+  for (int64_t i = 0; i < sourceRank - 2; ++i)
+    assert(sourceType.getShape()[i] == 1 &&
+           "expected leading dimensions to be unit");
+  int64_t rowIdx = sourceRank - 2;
+  int64_t columnIdx = sourceRank - 1;
+  int64_t sourceH = sourceType.getShape()[rowIdx];
+  int64_t sourceW = sourceType.getShape()[columnIdx];
 
   // Create a constant vector to hold the result of the reduction.
   TypedAttr zeroAttr = rewriter.getZeroAttr(sourceType.getElementType());
@@ -763,39 +770,46 @@ Value xegpu::lowerCrossLaneReductionToShuffles(
       DenseElementsAttr::get(acc.getType(), zeroAttr));
 
   // nSlices is the number of reduction operations needed to reduce the entire
-  // source vector. For example, if reductionDim is 0, we are reducing across
-  // rows, and each slice is a column of the source vector. So the number of
-  // slices is the number of columns, which is sourceW.
-  int nSlices = (reductionDim == 0) ? sourceW : sourceH;
+  // source vector. For example, if reductionDim is the row dim, we are
+  // reducing across rows, and each slice is a column. So the number of slices
+  // is the number of columns, which is sourceW.
+  int nSlices = (reductionDim == rowIdx) ? sourceW : sourceH;
 
   // For each slice of the source, extract the slice vector, do a reduction
   // and, insert the reduced value back to the result vector.
+  int64_t accRank = acc.getType().getRank();
   for (int i = 0; i < nSlices; ++i) {
-    SmallVector<int64_t, 2> sliceOffsets, sliceSizes;
-    if (reductionDim == 1) {
-      sliceOffsets = {i, 0};
-      sliceSizes = {1, sourceW};
+    // Build nD offsets, sizes, and strides. Leading unit dims get
+    // offset=0, size=1. The last two dims are set based on reductionDim.
+    SmallVector<int64_t> sliceOffsets(sourceRank, 0);
+    SmallVector<int64_t> sliceSizes(sourceRank, 1);
+    SmallVector<int64_t> strides(sourceRank, 1);
+    if (reductionDim == columnIdx) {
+      sliceOffsets[rowIdx] = i;
+      sliceSizes[columnIdx] = sourceW;
     } else {
-      sliceOffsets = {0, i};
-      sliceSizes = {sourceH, 1};
+      sliceOffsets[columnIdx] = i;
+      sliceSizes[rowIdx] = sourceH;
     }
 
     vector::ExtractStridedSliceOp extractOp =
         vector::ExtractStridedSliceOp::create(rewriter, loc, src, sliceOffsets,
-                                              sliceSizes, {1, 1});
+                                              sliceSizes, strides);
     int64_t nSliceElements = extractOp.getResult().getType().getNumElements();
     vector::ShapeCastOp slice = vector::ShapeCastOp::create(
         rewriter, loc,
         VectorType::get({nSliceElements}, sourceType.getElementType()),
         extractOp.getResult());
 
-    Value accExtract = vector::ExtractOp::create(rewriter, loc, acc, i);
+    SmallVector<int64_t> accIdx(accRank, 0);
+    accIdx[accRank - 1] = i;
+    Value accExtract = vector::ExtractOp::create(rewriter, loc, acc, accIdx);
     Value fullReduce =
         xegpu::subgroupReduction(loc, rewriter, slice, kind, reductionSize);
     fullReduce =
         vector::makeArithReduction(rewriter, loc, kind, fullReduce, accExtract);
-    reductionResult =
-        vector::InsertOp::create(rewriter, loc, fullReduce, reductionResult, i);
+    reductionResult = vector::InsertOp::create(rewriter, loc, fullReduce,
+                                               reductionResult, accIdx);
   }
   return reductionResult;
 }
diff --git a/mlir/test/Dialect/XeGPU/sg-to-wi-experimental-unit.mlir b/mlir/test/Dialect/XeGPU/sg-to-wi-experimental-unit.mlir
index 016b393e3d8bc..07c63f5933a9a 100644
--- a/mlir/test/Dialect/XeGPU/sg-to-wi-experimental-unit.mlir
+++ b/mlir/test/Dialect/XeGPU/sg-to-wi-experimental-unit.mlir
@@ -461,6 +461,93 @@ gpu.func @vector_multi_reduction_dim0_distributed_dim1_reduction(%laneid: index)
   gpu.return
 }
 
+// CHECK-LABEL: gpu.func @vector_multi_reduction_3d_leading_unit_dim_lane_local
+// CHECK:         %[[CST:.*]] = arith.constant dense<0.000000e+00> : vector<1x16x2xf32>
+// CHECK:         %[[CST_0:.*]] = arith.constant dense<0.000000e+00> : vector<1x2xf32>
+// CHECK:         %[[RED:.*]] = vector.multi_reduction <add>, %[[CST]], %[[CST_0]] [1] : vector<1x16x2xf32> to vector<1x2xf32>
+// CHECK:         gpu.return
+gpu.func @vector_multi_reduction_3d_leading_unit_dim_lane_local() {
+    %src = arith.constant
+      {layout_result_0 = #xegpu.layout<lane_layout = [1, 1, 16], lane_data = [1, 1, 1]>}
+      dense<0.0>  : vector<1x16x32xf32>
+    %acc = arith.constant
+      {layout_result_0 = #xegpu.slice<#xegpu.layout<lane_layout = [1, 1, 16], lane_data = [1, 1, 1]>, dims = [1]>}
+      dense<0.0>  : vector<1x32xf32>
+    %1 = vector.multi_reduction <add>, %src, %acc
+      {
+        layout_result_0 = #xegpu.slice<#xegpu.layout<lane_layout = [1, 1, 16], lane_data = [1, 1, 1]>, dims = [1]>
+      }
+      [1] : vector<1x16x32xf32> to vector<1x32xf32>
+  gpu.return
+}
+
+// Cross-lane 3D reduction with leading unit dim.
+// Source distributed on dim 1, reducing dim 1 => cross-lane shuffle.
+// CHECK-LABEL: gpu.func @vector_multi_reduction_3d_leading_unit_dim_cross_lane
+// CHECK:         %[[SRC:.*]] = arith.constant dense<0.000000e+00> : vector<1x1x2xf32>
+// CHECK:         %[[ACC:.*]] = arith.constant dense<0.000000e+00> : vector<1x2xf32>
+// CHECK:         %[[SLICE0:.*]] = vector.extract_strided_slice %[[SRC]]
+// CHECK-SAME:      {offsets = [0, 0, 0], sizes = [1, 1, 1], strides = [1, 1, 1]}
+// CHECK:         %[[FLAT0:.*]] = vector.shape_cast %[[SLICE0]] : vector<1x1x1xf32> to vector<1xf32>
+// CHECK:         %[[ACC0:.*]] = vector.extract %[[ACC]][0, 0] : f32 from vector<1x2xf32>
+// CHECK:         %[[RED0:.*]] = vector.reduction <add>, %[[FLAT0]] : vector<1xf32> into f32
+// CHECK:         %[[C16_0:.*]] = arith.constant 16 : i32
+// CHECK:         %[[C1:.*]] = arith.constant 1 : i32
+// CHECK:         %[[SHUF0_1:.*]], %{{.*}} = gpu.shuffle xor %[[RED0]], %[[C1]], %[[C16_0]] : f32
+// CHECK:         %[[ADD0_1:.*]] = arith.addf %[[RED0]], %[[SHUF0_1]] : f32
+// CHECK:         %[[C16_1:.*]] = arith.constant 16 : i32
+// CHECK:         %[[C2:.*]] = arith.constant 2 : i32
+// CHECK:         %[[SHUF0_2:.*]], %{{.*}} = gpu.shuffle xor %[[ADD0_1]], %[[C2]], %[[C16_1]] : f32
+// CHECK:         %[[ADD0_2:.*]] = arith.addf %[[ADD0_1]], %[[SHUF0_2]] : f32
+// CHECK:         %[[C16_2:.*]] = arith.constant 16 : i32
+// CHECK:         %[[C4:.*]] = arith.constant 4 : i32
+// CHECK:         %[[SHUF0_4:.*]], %{{.*}} = gpu.shuffle xor %[[ADD0_2]], %[[C4]], %[[C16_2]] : f32
+// CHECK:         %[[ADD0_4:.*]] = arith.addf %[[ADD0_2]], %[[SHUF0_4]] : f32
+// CHECK:         %[[C16_3:.*]] = arith.constant 16 : i32
+// CHECK:         %[[C8:.*]] = arith.constant 8 : i32
+// CHECK:         %[[SHUF0_8:.*]], %{{.*}} = gpu.shuffle xor %[[ADD0_4]], %[[C8]], %[[C16_3]] : f32
+// CHECK:         %[[ADD0_8:.*]] = arith.addf %[[ADD0_4]], %[[SHUF0_8]] : f32
+// CHECK:         %[[FINAL0:.*]] = arith.addf %[[ADD0_8]], %[[ACC0]] : f32
+// CHECK:         %[[INS0:.*]] = vector.insert %[[FINAL0]], %{{.*}} [0, 0] : f32 into vector<1x2xf32>
+// CHECK:         %[[SLICE1:.*]] = vector.extract_strided_slice %[[SRC]]
+// CHECK-SAME:      {offsets = [0, 0, 1], sizes = [1, 1, 1], strides = [1, 1, 1]}
+// CHECK:         %[[FLAT1:.*]] = vector.shape_cast %[[SLICE1]] : vector<1x1x1xf32> to vector<1xf32>
+// CHECK:         %[[ACC1:.*]] = vector.extract %[[ACC]][0, 1] : f32 from vector<1x2xf32>
+// CHECK:         %[[RED1:.*]] = vector.reduction <add>, %[[FLAT1]] : vector<1xf32> into f32
+// CHECK:         %[[C16_4:.*]] = arith.constant 16 : i32
+// CHECK:         %[[C1_1:.*]] = arith.constant 1 : i32
+// CHECK:         %[[SHUF1_1:.*]], %{{.*}} = gpu.shuffle xor %[[RED1]], %[[C1_1]], %[[C16_4]] : f32
+// CHECK:         %[[ADD1_1:.*]] = arith.addf %[[RED1]], %[[SHUF1_1]] : f32
+// CHECK:         %[[C16_5:.*]] = arith.constant 16 : i32
+// CHECK:         %[[C2_1:.*]] = arith.constant 2 : i32
+// CHECK:         %[[SHUF1_2:.*]], %{{.*}} = gpu.shuffle xor %[[ADD1_1]], %[[C2_1]], %[[C16_5]] : f32
+// CHECK:         %[[ADD1_2:.*]] = arith.addf %[[ADD1_1]], %[[SHUF1_2]] : f32
+// CHECK:         %[[C16_6:.*]] = arith.constant 16 : i32
+// CHECK:         %[[C4_1:.*]] = arith.constant 4 : i32
+// CHECK:         %[[SHUF1_4:.*]], %{{.*}} = gpu.shuffle xor %[[ADD1_2]], %[[C4_1]], %[[C16_6]] : f32
+// CHECK:         %[[ADD1_4:.*]] = arith.addf %[[ADD1_2]], %[[SHUF1_4]] : f32
+// CHECK:         %[[C16_7:.*]] = arith.constant 16 : i32
+// CHECK:         %[[C8_1:.*]] = arith.constant 8 : i32
+// CHECK:         %[[SHUF1_8:.*]], %{{.*}} = gpu.shuffle xor %[[ADD1_4]], %[[C8_1]], %[[C16_7]] : f32
+// CHECK:         %[[ADD1_8:.*]] = arith.addf %[[ADD1_4]], %[[SHUF1_8]] : f32
+// CHECK:         %[[FINAL1:.*]] = arith.addf %[[ADD1_8]], %[[ACC1]] : f32
+// CHECK:         vector.insert %[[FINAL1]], %[[INS0]] [0, 1] : f32 into vector<1x2xf32>
+// CHECK:         gpu.return
+gpu.func @vector_multi_reduction_3d_leading_unit_dim_cross_lane() {
+    %src = arith.constant
+      {layout_result_0 = #xegpu.layout<lane_layout = [1, 16, 1], lane_data = [1, 1, 1]>}
+      dense<0.0>  : vector<1x16x2xf32>
+    %acc = arith.constant
+      {layout_result_0 = #xegpu.slice<#xegpu.layout<lane_layout = [1, 16, 1], lane_data = [1, 1, 1]>, dims = [1]>}
+      dense<0.0>  : vector<1x2xf32>
+    %1 = vector.multi_reduction <add>, %src, %acc
+      {
+        layout_result_0 = #xegpu.slice<#xegpu.layout<lane_layout = [1, 16, 1], lane_data = [1, 1, 1]>, dims = [1]>
+      }
+      [1] : vector<1x16x2xf32> to vector<1x2xf32>
+  gpu.return
+}
+
 // CHECK-LABEL: gpu.func @vector_extract_from_2d
 // CHECK: %[[EXT:.*]] = vector.extract %{{.*}}[0] : vector<1xf32> from vector<4x1xf32>
 gpu.func @vector_extract_from_2d() {



More information about the Mlir-commits mailing list