[Mlir-commits] [mlir] [MLIR][XeGPU] Support leading unit dim (PR #185110)

Nishant Patel llvmlistbot at llvm.org
Fri Mar 6 13:09:31 PST 2026


https://github.com/nbpatel created https://github.com/llvm/llvm-project/pull/185110

None

>From 22255e97a5cfaa1b615e1527852f55d253742c68 Mon Sep 17 00:00:00 2001
From: nbpatel <nishant.b.patel at intel.com>
Date: Fri, 6 Mar 2026 19:29:32 +0000
Subject: [PATCH] Support leading unit dim

---
 .../Transforms/XeGPUSubgroupDistribute.cpp    | 44 ++++++++++++------
 mlir/lib/Dialect/XeGPU/Utils/XeGPUUtils.cpp   | 46 +++++++++++++------
 .../XeGPU/subgroup-distribute-unit.mlir       | 41 +++++++++++++++++
 3 files changed, 102 insertions(+), 29 deletions(-)

diff --git a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUSubgroupDistribute.cpp b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUSubgroupDistribute.cpp
index 38bc95d39c2c6..90d61eec48369 100644
--- a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUSubgroupDistribute.cpp
+++ b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUSubgroupDistribute.cpp
@@ -1321,17 +1321,31 @@ struct VectorMultiReductionDistribution : public gpu::WarpDistributionPattern {
         cast<vector::MultiDimReductionOp>(yieldOperand->get().getDefiningOp());
     unsigned operandIdx = yieldOperand->getOperandNumber();
     VectorType sourceType = reductionOp.getSourceVectorType();
-    // Only 2D vectors are supported.
-    if (sourceType.getRank() != 2)
+    int64_t sourceRank = sourceType.getRank();
+    // Need at least a 2D source vector.
+    if (sourceRank < 2)
       return rewriter.notifyMatchFailure(warpOp,
-                                         "Only 2D reductions are supported.");
+                                         "Only 2D+ reductions are supported.");
+    // All leading dimensions (the first rank - 2 dims) must be unit (size 1).
+    for (int64_t i = 0; i < sourceRank - 2; ++i) {
+      if (sourceType.getShape()[i] != 1)
+        return rewriter.notifyMatchFailure(
+            warpOp, "Only unit dimensions allowed for the leading dimensions.");
+    }
+    // Effective dimension indices (last 2 dims of the source).
+    int64_t dim0Idx = sourceRank - 2;
+    int64_t dim1Idx = sourceRank - 1;
     ArrayRef<int64_t> reductionDims = reductionOp.getReductionDims();
-    // Only 1 reduction dimension supported. This also ensures that the result
-    // is vector type.
-    if (reductionDims.size() != 1)
+    // Find non-trivial effective reduction dims among the last 2 dims.
+    SmallVector<int64_t> effectiveReductionDims;
+    for (int64_t d : reductionDims) {
+      if ((d == dim0Idx || d == dim1Idx) && sourceType.getShape()[d] > 1)
+        effectiveReductionDims.push_back(d);
+    }
+    if (effectiveReductionDims.size() != 1)
       return rewriter.notifyMatchFailure(
-          warpOp, "Only 1 reduction dimension is supported.");
-    int64_t reductionDim = reductionDims[0];
+          warpOp, "Only 1 non-trivial effective reduction dim is supported.");
+    int64_t reductionDim = effectiveReductionDims[0];
     VectorType distributedResultType =
         cast<VectorType>(warpOp.getResult(operandIdx).getType());
     VectorType resultType = cast<VectorType>(reductionOp.getType());
@@ -1344,15 +1358,16 @@ struct VectorMultiReductionDistribution : public gpu::WarpDistributionPattern {
       return rewriter.notifyMatchFailure(
           warpOp, "Failed to distribute the source vector type.");
     VectorType sourceDistType = sourceDistTypeOrFailure.value();
-    // Only single dimension distribution is supported.
+    // Only single dimension distribution among the last 2 dims is supported.
     bool dim0Distributed =
-        sourceDistType.getShape()[0] != sourceType.getShape()[0];
+        sourceDistType.getShape()[dim0Idx] != sourceType.getShape()[dim0Idx];
     bool dim1Distributed =
-        sourceDistType.getShape()[1] != sourceType.getShape()[1];
+        sourceDistType.getShape()[dim1Idx] != sourceType.getShape()[dim1Idx];
     if (dim0Distributed && dim1Distributed)
       return rewriter.notifyMatchFailure(
           warpOp, "Expecting source to be distributed in a single dimension.");
-    int64_t sourceDistDim = dim0Distributed ? 0 : (dim1Distributed ? 1 : -1);
+    int64_t sourceDistDim =
+        dim0Distributed ? dim0Idx : (dim1Distributed ? dim1Idx : -1);
     if (sourceDistDim == -1)
       return rewriter.notifyMatchFailure(
           warpOp, "Expecting a distributed source vector.");
@@ -1371,8 +1386,9 @@ struct VectorMultiReductionDistribution : public gpu::WarpDistributionPattern {
     // |  dim-1 distributed   |       0        | distributed    |
     // |  dim-1 distributed   |       1        | broadcasted    |
 
-    bool isReductionLaneLocal = (sourceDistDim == 0 && reductionDim == 1) ||
-                                (sourceDistDim == 1 && reductionDim == 0);
+    bool isReductionLaneLocal =
+        (sourceDistDim == dim0Idx && reductionDim == dim1Idx) ||
+        (sourceDistDim == dim1Idx && reductionDim == dim0Idx);
     if (isReductionLaneLocal && !resultDistributed)
       return rewriter.notifyMatchFailure(
           warpOp, "Expecting a distributed result for lane-local reduction.");
diff --git a/mlir/lib/Dialect/XeGPU/Utils/XeGPUUtils.cpp b/mlir/lib/Dialect/XeGPU/Utils/XeGPUUtils.cpp
index 3271e73e0b571..c31eda2aa2a60 100644
--- a/mlir/lib/Dialect/XeGPU/Utils/XeGPUUtils.cpp
+++ b/mlir/lib/Dialect/XeGPU/Utils/XeGPUUtils.cpp
@@ -671,12 +671,19 @@ Value xegpu::lowerToVectorReductions(TypedValue<VectorType> src,
                                      vector::CombiningKind kind,
                                      int64_t reductionDim, Location loc,
                                      PatternRewriter &rewriter) {
-  // Expecting a 2D source vector.
-  assert(src.getType().getRank() == 2 && "expected a 2D source vector");
   VectorType sourceType = src.getType();
-  int64_t sourceH = sourceType.getShape()[0];
-  int64_t sourceW = sourceType.getShape()[1];
-  int nSlices = (reductionDim == 0) ? sourceW : sourceH;
+  int64_t sourceRank = sourceType.getRank();
+  // Expecting at least a 2D source vector. Leading dimensions (all except the
+  // last two) must be unit.
+  assert(sourceRank >= 2 && "expected at least a 2D source vector");
+  for (int64_t i = 0; i < sourceRank - 2; ++i)
+    assert(sourceType.getShape()[i] == 1 &&
+           "expected leading dimensions to be unit");
+  int64_t dim0Idx = sourceRank - 2;
+  int64_t dim1Idx = sourceRank - 1;
+  int64_t sourceH = sourceType.getShape()[dim0Idx];
+  int64_t sourceW = sourceType.getShape()[dim1Idx];
+  int nSlices = (reductionDim == dim0Idx) ? sourceW : sourceH;
   // Create a constant vector to hold the result of the reduction.
   TypedAttr zeroAttr = rewriter.getZeroAttr(sourceType.getElementType());
   Value reductionResult = arith::ConstantOp::create(
@@ -688,19 +695,24 @@ Value xegpu::lowerToVectorReductions(TypedValue<VectorType> src,
   xegpu::setTemporaryLayout(cast<OpResult>(reductionResult), accLayout);
   // For each slice of the source, extract the slice vector, do a reduction
   // and, insert the reduced value back to the result vector.
+  int64_t accRank = acc.getType().getRank();
   for (int i = 0; i < nSlices; ++i) {
-    SmallVector<int64_t, 2> sliceOffsets, sliceSizes;
-    if (reductionDim == 1) {
-      sliceOffsets = {i, 0};
-      sliceSizes = {1, sourceW};
+    // Build nD offsets, sizes, and strides. Leading unit dims get
+    // offset=0, size=1. The last two dims are set based on reductionDim.
+    SmallVector<int64_t> sliceOffsets(sourceRank, 0);
+    SmallVector<int64_t> sliceSizes(sourceRank, 1);
+    SmallVector<int64_t> strides(sourceRank, 1);
+    if (reductionDim == dim1Idx) {
+      sliceOffsets[dim0Idx] = i;
+      sliceSizes[dim1Idx] = sourceW;
     } else {
-      sliceOffsets = {0, i};
-      sliceSizes = {sourceH, 1};
+      sliceOffsets[dim1Idx] = i;
+      sliceSizes[dim0Idx] = sourceH;
     }
 
     vector::ExtractStridedSliceOp extractOp =
         vector::ExtractStridedSliceOp::create(rewriter, loc, src, sliceOffsets,
-                                              sliceSizes, {1, 1});
+                                              sliceSizes, strides);
     // Extract strided slice has the same layout as src.
     xegpu::setTemporaryLayout(extractOp->getOpResult(0), srcLayout);
 
@@ -716,11 +728,15 @@ Value xegpu::lowerToVectorReductions(TypedValue<VectorType> src,
     xegpu::setTemporaryLayout(slice->getOpOperand(0), srcLayout);
     xegpu::setTemporaryLayout(slice->getOpResult(0), accLayout);
     // Extract and reduction results in scalars, so no result layout is needed.
-    Value accExtract = vector::ExtractOp::create(rewriter, loc, acc, i);
+    // Build a multi-dim index into acc (accRank dims: the source shape with the
+    // reduced dim(s) removed). Leading dims get index 0; the innermost gets i.
+    SmallVector<int64_t> accIdx(accRank, 0);
+    accIdx[accRank - 1] = i;
+    Value accExtract = vector::ExtractOp::create(rewriter, loc, acc, accIdx);
     Value reduction = vector::ReductionOp::create(
         rewriter, loc, kind, slice.getResult(), accExtract);
-    reductionResult =
-        vector::InsertOp::create(rewriter, loc, reduction, reductionResult, i);
+    reductionResult = vector::InsertOp::create(rewriter, loc, reduction,
+                                               reductionResult, accIdx);
     // Insert op should have the same layout as the accumulator.
     xegpu::setTemporaryLayout(cast<OpResult>(reductionResult), accLayout);
   }
diff --git a/mlir/test/Dialect/XeGPU/subgroup-distribute-unit.mlir b/mlir/test/Dialect/XeGPU/subgroup-distribute-unit.mlir
index 8b980c5083af3..3d4abc6cb7bd7 100644
--- a/mlir/test/Dialect/XeGPU/subgroup-distribute-unit.mlir
+++ b/mlir/test/Dialect/XeGPU/subgroup-distribute-unit.mlir
@@ -396,6 +396,47 @@ gpu.func @vector_multi_reduction_dim0_distributed_dim0_reduction(%laneid: index)
 }
 
 
+// CHECK-LABEL: gpu.func @vector_multi_reduction_3d_leading_unit_dim
+// CHECK:       %[[ACC:.*]] = arith.constant {{.*}} dense<0.000000e+00> : vector<1x32xf32>
+// CHECK:       %[[W:.*]]:3 = gpu.warp_execute_on_lane_0(%{{.*}})[16]
+// CHECK-SAME:    -> (vector<1x2xf32>, vector<1x16x2xf32>, vector<1x2xf32>) {
+// CHECK:         %[[SRC:.*]] = "some_def"() {{.*}} : () -> vector<1x16x32xf32>
+// CHECK:         gpu.yield %{{.*}}, %[[SRC]], %[[ACC]] : vector<1x32xf32>, vector<1x16x32xf32>, vector<1x32xf32>
+// CHECK-NEXT:  }
+// CHECK:       %[[T1:.*]] = vector.extract_strided_slice %[[W]]#1
+// CHECK-SAME:    {offsets = [0, 0, 0], sizes = [1, 16, 1], strides = [1, 1, 1]} : vector<1x16x2xf32> to vector<1x16x1xf32>
+// CHECK:       %[[T2:.*]] = vector.shape_cast %[[T1]] : vector<1x16x1xf32> to vector<16xf32>
+// CHECK:       %[[T3:.*]] = vector.extract %[[W]]#2[0, 0] : f32 from vector<1x2xf32>
+// CHECK:       %[[T4:.*]] = vector.reduction <add>, %[[T2]], %[[T3]] : vector<16xf32> into f32
+// CHECK:       %[[T5:.*]] = vector.extract_strided_slice %[[W]]#1
+// CHECK-SAME:    {offsets = [0, 0, 1], sizes = [1, 16, 1], strides = [1, 1, 1]} : vector<1x16x2xf32> to vector<1x16x1xf32>
+// CHECK:       %[[T6:.*]] = vector.shape_cast %[[T5]] : vector<1x16x1xf32> to vector<16xf32>
+// CHECK:       %[[T7:.*]] = vector.extract %[[W]]#2[0, 1] : f32 from vector<1x2xf32>
+// CHECK:       %[[T8:.*]] = vector.reduction <add>, %[[T6]], %[[T7]] : vector<16xf32> into f32
+// CHECK:       %[[T9:.*]] = vector.from_elements %[[T4]], %[[T8]] : vector<1x2xf32>
+gpu.func @vector_multi_reduction_3d_leading_unit_dim(%laneid: index) {  // Rank-3 source with a leading unit dim.
+  %c0 = arith.constant 0 : index  // Not referenced below.
+  %r = gpu.warp_execute_on_lane_0(%laneid)[16] -> (vector<1x2xf32>) {  // 16 lanes: 1x32 result -> 1x2 per lane.
+    %src = "some_def"()
+      {layout_result_0 = #xegpu.layout<lane_layout = [1, 1, 16], lane_data = [1, 1, 1]>}  // Lanes laid out along dim 2.
+      : () -> (vector<1x16x32xf32>)
+    %acc = arith.constant
+      {layout_result_0 = #xegpu.slice<#xegpu.layout<lane_layout = [1, 1, 16], lane_data = [1, 1, 1]>, dims = [1]>}
+      dense<0.0>  : vector<1x32xf32>
+    %1 = vector.multi_reduction <add>, %src, %acc  // Reduces dim 1, which is not lane-distributed (lane-local).
+      {
+        layout_operand_0 = #xegpu.layout<lane_layout = [1, 1, 16], lane_data = [1, 1, 1]>,
+        layout_operand_1 = #xegpu.slice<#xegpu.layout<lane_layout = [1, 1, 16], lane_data = [1, 1, 1]>, dims = [1]>,
+        layout_result_0 = #xegpu.slice<#xegpu.layout<lane_layout = [1, 1, 16], lane_data = [1, 1, 1]>, dims = [1]>
+      }
+      [1] : vector<1x16x32xf32> to vector<1x32xf32>
+    gpu.yield %1 : vector<1x32xf32>
+  }
+  "some_user_op"(%r) : (vector<1x2xf32>) -> ()
+  gpu.return
+}
+
+
 // CHECK-LABEL: gpu.func @scatter_ops_chunksize({{.*}}) {
 // CHECK:       %[[OFFSETS:.*]] = arith.constant {{.*}} dense<12> : vector<16xindex>
 // CHECK:       %[[MASKS:.*]] = arith.constant {{.*}} dense<true> : vector<16xi1>



More information about the Mlir-commits mailing list