[Mlir-commits] [mlir] [MLIR][XeGPU] Enhance multi-reduction layout propagation rules (PR #186308)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Thu Mar 12 21:05:02 PDT 2026


llvmbot wrote:


<!--LLVM PR SUMMARY COMMENT-->

@llvm/pr-subscribers-mlir

Author: Jianhui Li (Jianhui-Li)

<details>
<summary>Changes</summary>

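This PR enhances the multi-reduction layout propagation:
1. Improve the inst_data and lane_layout setup to support source shapes whose innermost dimension is smaller than the subgroup size (i.e., a fractional subgroup). The innermost value is now clamped to min(subgroup size, innermost dimension).
2. Improve the sg_layout/sg_data setup to reuse the (nested) slice layout from the consumer op when it is compatible with the source shape, and to propagate the consumer's order.
3. Remove the 2-D-only assertions from load_matrix/store_matrix layout propagation.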

---
Full diff: https://github.com/llvm/llvm-project/pull/186308.diff


3 Files Affected:

- (modified) mlir/lib/Dialect/XeGPU/Transforms/XeGPULayoutImpl.cpp (+89-46) 
- (modified) mlir/lib/Dialect/XeGPU/Transforms/XeGPUPropagateLayout.cpp (-2) 
- (modified) mlir/test/Dialect/XeGPU/propagate-layout-subgroup.mlir (+29-19) 


``````````diff
diff --git a/mlir/lib/Dialect/XeGPU/Transforms/XeGPULayoutImpl.cpp b/mlir/lib/Dialect/XeGPU/Transforms/XeGPULayoutImpl.cpp
index 314bc78a3653f..2812efdaee27a 100644
--- a/mlir/lib/Dialect/XeGPU/Transforms/XeGPULayoutImpl.cpp
+++ b/mlir/lib/Dialect/XeGPU/Transforms/XeGPULayoutImpl.cpp
@@ -437,75 +437,118 @@ xegpu::SliceAttr xegpu::setupMultiReductionResultLayout(
     return DenseI32ArrayAttr::get(context, vec32);
   };
 
-  // Extract original plain layout for workgroup/subgroup size recovery
-  xegpu::SliceAttr consumerSliceLayout =
-      dyn_cast<xegpu::SliceAttr>(consumerLayout);
-  DistributeLayoutAttr plainLayout =
-      consumerSliceLayout ? consumerSliceLayout.flatten().getParent()
-                          : consumerLayout;
+  // Helper lambda to check if the layout from consumer can be reused for the
+  // source shape
+  auto isLayoutCompatibleWithSrcShape =
+      [&](ArrayRef<int64_t> srcShape,
+          xegpu::DistributeLayoutAttr srcLayout) -> bool {
+    SmallVector<int64_t> sgLayout = srcLayout.getEffectiveSgLayoutAsInt();
+    SmallVector<int64_t> laneLayout = srcLayout.getEffectiveLaneLayoutAsInt();
+    if (static_cast<size_t>(srcLayout.getRank()) != srcShape.size())
+      return false;
+    for (size_t i = 0; i < srcShape.size(); i++) {
+      if (!sgLayout.empty() && srcShape[i] % sgLayout[i] != 0)
+        return false;
+      if (!laneLayout.empty() && srcShape[i] % laneLayout[i] != 0)
+        return false;
+    }
+    return true;
+  };
 
+  // Extract original plain layout for workgroup/subgroup size recovery
+  xegpu::DistributeLayoutAttr rootPlainLayout = consumerLayout;
+  while (auto sliceAttr = dyn_cast<xegpu::SliceAttr>(rootPlainLayout)) {
+    rootPlainLayout = sliceAttr.getParent();
+  }
+  auto sgLayoutVec = rootPlainLayout.getEffectiveSgLayoutAsInt();
+  const int workgroupSize = std::accumulate(
+      sgLayoutVec.begin(), sgLayoutVec.end(), 1, std::multiplies<int64_t>());
   const int subgroupSize = uArch->getSubgroupSize();
   int64_t maxReduceVectorSize = 1; // could extend to spirv vector Size
 
-  xegpu::DistributeLayoutAttr srcLayout;
+  xegpu::SliceAttr consumerSliceLayout =
+      dyn_cast<xegpu::SliceAttr>(consumerLayout);
+  SmallVector<int64_t> consumerSgLayout =
+      consumerLayout.getEffectiveSgLayoutAsInt();
+  SmallVector<int64_t> consumerLaneLayout =
+      consumerLayout.getEffectiveLaneLayoutAsInt();
+  SmallVector<int64_t> consumerOrder = consumerLayout.getEffectiveOrderAsInt();
+  DenseI32ArrayAttr orderAttr = consumerLayout.getOrder();
 
+  xegpu::DistributeLayoutAttr srcLayout;
   if (layoutKind == xegpu::LayoutKind::Subgroup) {
-    auto sgLayoutVec = plainLayout.getEffectiveSgLayoutAsInt();
-    const int workgroupSize = std::accumulate(
-        sgLayoutVec.begin(), sgLayoutVec.end(), 1, std::multiplies<int64_t>());
-    SmallVector<int64_t> sgLayout(srcRank), sgData(srcRank);
-    SmallVector<int64_t> consumerSgLayout =
-        consumerLayout.getEffectiveSgLayoutAsInt();
-    int remainingSgCount = workgroupSize;
-    int consumerIdx = consumerSgLayout.size() - 1;
-
-    // First pass: Match consumer's layout on non-reduction dimensions
-    for (int i = srcRank - 1; i >= 0; i--) {
-      if (!llvm::is_contained(reductionDims, i) && consumerIdx >= 0) {
-        sgLayout[i] = consumerSgLayout[consumerIdx];
-        assert((srcShape[i] % sgLayout[i] == 0) &&
-               "source shape not divisible by consumer sg_layout");
-        sgData[i] = srcShape[i] / sgLayout[i];
-        remainingSgCount /= sgLayout[i];
-        consumerIdx--;
+    if (consumerSliceLayout &&
+        consumerSliceLayout.getDims().asArrayRef().equals(reductionDims) &&
+        isLayoutCompatibleWithSrcShape(srcShape,
+                                       consumerSliceLayout.getParent())) {
+      int64_t sgDataValue = -1;
+      srcLayout = consumerSliceLayout.getParent();
+      SmallVector<int64_t> sgLayoutFromConsumer =
+          srcLayout.getEffectiveSgLayoutAsInt();
+      for (int dim = 0; dim < srcRank; dim++) {
+        sgDataValue = srcShape[dim] / sgLayoutFromConsumer[dim];
+        srcLayout = srcLayout.setDimData(dim, sgDataValue, -1, -1);
       }
-    }
+    } else {
 
-    // Second pass: Distribute remaining subgroups across reduction dimensions
-    for (int i = srcRank - 1; i >= 0; i--) {
-      if (llvm::is_contained(reductionDims, i)) {
-        sgLayout[i] =
-            std::min(srcShape[i], static_cast<int64_t>(remainingSgCount));
-        assert((srcShape[i] % sgLayout[i] == 0) &&
-               "source shape not divisible by sg_layout");
-        sgData[i] = srcShape[i] / sgLayout[i];
-        remainingSgCount /= sgLayout[i];
+      SmallVector<int64_t> sgLayout(srcRank), sgData(srcRank), order(srcRank);
+      int remainingSgCount = workgroupSize;
+      int consumerIdx = consumerSgLayout.size() - 1;
+
+      // First pass: Match consumer's layout on non-reduction dimensions
+      for (int i = srcRank - 1; i >= 0; i--) {
+        if (!llvm::is_contained(reductionDims, i) && consumerIdx >= 0) {
+          sgLayout[i] = consumerSgLayout[consumerIdx];
+          assert((srcShape[i] % sgLayout[i] == 0) &&
+                 "source shape not divisible by consumer sg_layout");
+          sgData[i] = srcShape[i] / sgLayout[i];
+          remainingSgCount /= sgLayout[i];
+          order[i] = consumerOrder[consumerIdx];
+          consumerIdx--;
+        }
       }
-    }
 
-    assert(remainingSgCount == 1 && "not all subgroups distributed");
-    srcLayout = xegpu::LayoutAttr::get(
-        context, toInt32Attr(sgLayout), toInt32Attr(sgData),
-        /*inst_data =*/nullptr, /*lane_layout =*/nullptr,
-        /*lane_data =*/nullptr, /*order =*/nullptr);
+      // Second pass: Distribute remaining subgroups across reduction dimensions
+      int64_t remainOrder = consumerSgLayout.size();
+      for (int i = srcRank - 1; i >= 0; i--) {
+        if (llvm::is_contained(reductionDims, i)) {
+          sgLayout[i] =
+              std::min(srcShape[i], static_cast<int64_t>(remainingSgCount));
+          assert((srcShape[i] % sgLayout[i] == 0) &&
+                 "source shape not divisible by sg_layout");
+          sgData[i] = srcShape[i] / sgLayout[i];
+          remainingSgCount /= sgLayout[i];
+          order[i] = remainOrder++;
+        }
+      }
 
+      assert(remainingSgCount == 1 && "not all subgroups distributed");
+      srcLayout = xegpu::LayoutAttr::get(
+          context, toInt32Attr(sgLayout), toInt32Attr(sgData),
+          /*inst_data =*/nullptr, /*lane_layout =*/nullptr,
+          /*lane_data =*/nullptr, /*order =*/
+          (!orderAttr || orderAttr.empty()) ? nullptr : toInt32Attr(order));
+    }
   } else if (layoutKind == xegpu::LayoutKind::InstData) {
 
     SmallVector<int64_t> instData(srcRank, 1);
     instData[srcRank - 2] =
         std::min(maxReduceVectorSize, srcShape[srcRank - 2]);
-    instData[srcRank - 1] = subgroupSize;
+    instData[srcRank - 1] =
+        std::min(static_cast<int64_t>(subgroupSize), srcShape[srcRank - 1]);
     srcLayout = xegpu::LayoutAttr::get(context, toInt32Attr(instData));
 
   } else if (layoutKind == xegpu::LayoutKind::Lane) {
 
     SmallVector<int64_t> laneLayout(srcRank, 1), laneData(srcRank, 1);
-    laneLayout[srcRank - 1] = subgroupSize;
+    laneLayout[srcRank - 1] =
+        std::min(static_cast<int64_t>(subgroupSize), srcShape[srcRank - 1]);
     laneData[srcRank - 2] =
         std::min(maxReduceVectorSize, srcShape[srcRank - 2]);
-    srcLayout = xegpu::LayoutAttr::get(context, toInt32Attr(laneLayout),
-                                       toInt32Attr(laneData),
-                                       consumerLayout.getOrder());
+    srcLayout = xegpu::LayoutAttr::get(
+        context, toInt32Attr(laneLayout), toInt32Attr(laneData),
+        (!orderAttr || orderAttr.empty()) ? nullptr
+                                          : toInt32Attr(consumerOrder));
   }
 
   return xegpu::SliceAttr::get(context, srcLayout,
diff --git a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUPropagateLayout.cpp b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUPropagateLayout.cpp
index 57db05f4a3b74..97a30de42cf2e 100644
--- a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUPropagateLayout.cpp
+++ b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUPropagateLayout.cpp
@@ -1143,7 +1143,6 @@ void LayoutInfoPropagation::visitLoadMatrixOp(
   if (!hasParamsOfLayoutKind(anchorLayout)) {
     VectorType resVecTy =
         llvm::cast<VectorType>(loadMatrixOp.getRes().getType());
-    assert(resVecTy.getRank() == 2 && "Expecting 2D vector for store matrix.");
     const uArch *uArch = getUArch(getChipStr(loadMatrixOp).value_or(""));
     if (!uArch)
       return;
@@ -1164,7 +1163,6 @@ void LayoutInfoPropagation::visitStoreMatrixOp(
   } else {
     VectorType srcVecTy =
         llvm::cast<VectorType>(storeMatrix.getData().getType());
-    assert(srcVecTy.getRank() == 2 && "Expecting 2D vector for store matrix.");
     const uArch *uArch = getUArch(getChipStr(storeMatrix).value_or(""));
     if (!uArch)
       return;
diff --git a/mlir/test/Dialect/XeGPU/propagate-layout-subgroup.mlir b/mlir/test/Dialect/XeGPU/propagate-layout-subgroup.mlir
index 9ee3de4490727..d730d04c819fa 100644
--- a/mlir/test/Dialect/XeGPU/propagate-layout-subgroup.mlir
+++ b/mlir/test/Dialect/XeGPU/propagate-layout-subgroup.mlir
@@ -128,8 +128,7 @@ gpu.module @test {
 gpu.module @test {
 // CHECK-LABEL: vector_row_reduction
 // CHECK: %[[REDUCE:.*]] = vector.multi_reduction <add>, %{{.*}}, %{{.*}} {layout_result_0 = #xegpu.slice<#xegpu.layout<sg_layout = [32, 1], sg_data = [1, 64]>, dims = [1]>}
-  gpu.func @vector_row_reduction(%src: memref<32x64xf32>, %dst: memref<32xf32>) kernel attributes
-      {known_block_size = array<i32: 1, 32, 1>} {
+  gpu.func @vector_row_reduction(%src: memref<32x64xf32>, %dst: memref<32xf32>) {
     %cst = arith.constant dense<0.000000e+00> : vector<32xf32>
     %tdesc_src = xegpu.create_nd_tdesc %src : memref<32x64xf32> -> !xegpu.tensor_desc<32x64xf32>
     %load = xegpu.load_nd %tdesc_src : !xegpu.tensor_desc<32x64xf32> -> vector<32x64xf32>
@@ -144,8 +143,7 @@ gpu.module @test {
 // -----
 gpu.module @test {
 // CHECK-LABEL: vector_nest_reduction
-  gpu.func @vector_nest_reduction(%src: memref<32x128xf32>, %dst: memref<32xf32>) kernel attributes
-      {known_block_size = array<i32: 1, 32, 1>} {
+  gpu.func @vector_nest_reduction(%src: memref<32x128xf32>, %dst: memref<32xf32>) {
     %cst = arith.constant dense<0.000000e+00> : vector<32xf32>
     %cst1 = arith.constant dense<0.000000e+00> : vector<32x128xf32>
     %tdesc_src = xegpu.create_nd_tdesc %src : memref<32x128xf32> -> !xegpu.tensor_desc<32x128xf32>
@@ -167,21 +165,33 @@ gpu.module @test {
 
 // -----
 gpu.module @test {
-  // CHECK-LABEL: broadcast_both_leadingdims_innerdims
-  gpu.func @broadcast_both_leadingdims_innerdims(%arg0: memref<32x2x192xf32>, %arg1: memref<32x2x192xf32>, %arg2: memref<32x2x192xf32>) kernel attributes {known_block_size = array<i32: 768, 1, 1>, known_grid_size = array<i32: 16, 1, 1>} {
-    // CHECK: arith.constant {layout_result_0 = #xegpu.layout<sg_layout = [2, 2, 6, 1], sg_data = [1, 1, 1, 32]>} dense<true> : vector<2x2x6x32xi1>
-    %cst = arith.constant dense<true> : vector<2x2x6x32xi1>
-    // CHECK: arith.constant {layout_result_0 = #xegpu.layout<sg_layout = [2, 2, 6, 1], sg_data = [1, 1, 1, 32]>} dense<1.000000e+00> : vector<2x2x6x32xf32>
-    %cst_0 = arith.constant dense<1.000000e+00> : vector<2x2x6x32xf32>
-    %intptr = memref.extract_aligned_pointer_as_index %arg2 : memref<32x2x192xf32> -> index
-    %0 = arith.index_cast %intptr : index to i64
-    // CHECK: vector.step {layout_result_0 = #xegpu.slice<#xegpu.slice<#xegpu.layout<sg_layout = [2, 2, 6, 1], sg_data = [1, 1, 1, 1]>, dims = [0, 1]>, dims = [1]>} : vector<6xindex>
-    %1 = vector.step : vector<6xindex>
-    // CHECK: vector.shape_cast {{.*}} {layout_result_0 = #xegpu.slice<#xegpu.layout<sg_layout = [2, 2, 6, 1], sg_data = [1, 1, 1, 1]>, dims = [0, 1]>} : vector<6xindex> to vector<6x1xindex>
-    %2 = vector.shape_cast %1 : vector<6xindex> to vector<6x1xindex>
-    // CHECK: vector.broadcast {{.*}} {layout_result_0 = #xegpu.layout<sg_layout = [2, 2, 6, 1], sg_data = [1, 1, 1, 32]>} : vector<6x1xindex> to vector<2x2x6x32xindex>
-    %3 = vector.broadcast %2 : vector<6x1xindex> to vector<2x2x6x32xindex>
-    xegpu.store %cst_0, %0[%3], %cst <{layout = #xegpu.layout<sg_layout = [2, 2, 6, 1], sg_data = [1, 1, 1, 32]>}> : vector<2x2x6x32xf32>, i64, vector<2x2x6x32xindex>, vector<2x2x6x32xi1>
+// CHECK-LABEL: vector_nest_reduction_with_nest_slice_layout
+// CHECK: %[[CST:.*]] = arith.constant {layout_result_0 = #xegpu.slice<#xegpu.slice<#xegpu.layout<sg_layout = [1, 4, 8], sg_data = [4, 8, 16]>, dims = [0]>, dims = [1]>} dense<0.000000e+00> : vector<32xf32>
+// CHECK: %[[CST0:.*]] = arith.constant {layout_result_0 = #xegpu.slice<#xegpu.layout<sg_layout = [1, 4, 8], sg_data = [4, 8, 16]>, dims = [0]>} dense<0.000000e+00> : vector<32x128xf32>
+// CHECK: %[[TDESC:.*]] = xegpu.create_nd_tdesc %{{.*}} : memref<32x128xf32> -> !xegpu.tensor_desc<32x128xf32, #xegpu.slice<#xegpu.layout<sg_layout = [1, 4, 8], sg_data = [4, 8, 16]>, dims = [0]>>
+// CHECK: %[[LOAD:.*]] = xegpu.load_nd %[[TDESC]] <{layout = #xegpu.slice<#xegpu.layout<sg_layout = [1, 4, 8], sg_data = [4, 8, 16]>, dims = [0]>}>
+// CHECK-SAME: -> vector<32x128xf32>
+// CHECK: %[[BCAST1:.*]] = vector.broadcast %[[LOAD]] {layout_result_0 = #xegpu.layout<sg_layout = [1, 4, 8], sg_data = [4, 8, 16]>} : vector<32x128xf32> to vector<4x32x128xf32>
+// CHECK: %[[REDUCE1:.*]] = vector.multi_reduction <add>, %[[BCAST1]], %[[CST0]]
+// CHECK-SAME: {layout_result_0 = #xegpu.slice<#xegpu.layout<sg_layout = [1, 4, 8], sg_data = [4, 8, 16]>, dims = [0]>} [0] : vector<4x32x128xf32> to vector<32x128xf32>
+// CHECK: %[[REDUCE2:.*]] = vector.multi_reduction <add>, %[[REDUCE1]], %[[CST]]
+// CHECK-SAME: {layout_result_0 = #xegpu.slice<#xegpu.slice<#xegpu.layout<sg_layout = [1, 4, 8], sg_data = [4, 8, 16]>, dims = [0]>, dims = [1]>} [1] : vector<32x128xf32> to vector<32xf32>
+// CHECK: %[[MASK:.*]] = arith.constant {layout_result_0 = #xegpu.slice<#xegpu.slice<#xegpu.layout<sg_layout = [1, 4, 8], sg_data = [4, 8, 32]>, dims = [0]>, dims = [1]>} dense<true> : vector<32xi1>
+// CHECK: %[[OFFSET:.*]] = vector.step {layout_result_0 = #xegpu.slice<#xegpu.slice<#xegpu.layout<sg_layout = [1, 4, 8], sg_data = [4, 8, 32]>, dims = [0]>, dims = [1]>} : vector<32xindex>
+// CHECK: xegpu.store %[[REDUCE2]], %{{.*}}[%[[OFFSET]]], %[[MASK]]
+// CHECK-SAME: <{layout = #xegpu.slice<#xegpu.slice<#xegpu.layout<sg_layout = [1, 4, 8], sg_data = [4, 8, 32]>, dims = [0]>, dims = [1]>}>
+// CHECK-SAME: : vector<32xf32>, memref<32xf32>, vector<32xindex>, vector<32xi1>
+  gpu.func @vector_nest_reduction_with_nest_slice_layout(%src: memref<32x128xf32>, %dst: memref<32xf32>) {
+    %cst = arith.constant dense<0.000000e+00> : vector<32xf32>
+    %cst1 = arith.constant dense<0.000000e+00> : vector<32x128xf32>
+    %tdesc_src = xegpu.create_nd_tdesc %src : memref<32x128xf32> -> !xegpu.tensor_desc<32x128xf32>
+    %load = xegpu.load_nd %tdesc_src : !xegpu.tensor_desc<32x128xf32> -> vector<32x128xf32>
+    %bcast1 = vector.broadcast %load: vector<32x128xf32> to vector<4x32x128xf32>
+    %bcast = vector.multi_reduction <add>, %bcast1, %cst1 [0]: vector<4x32x128xf32> to vector<32x128xf32>
+    %reduce = vector.multi_reduction <add>, %bcast, %cst [1] : vector<32x128xf32> to vector<32xf32>
+    %mask = arith.constant dense<1>: vector<32xi1>
+    %offset = vector.step : vector<32xindex>
+    xegpu.store %reduce, %dst[%offset], %mask {layout = #xegpu.slice<#xegpu.slice<#xegpu.layout<sg_layout = [1, 4, 8], sg_data = [4, 8, 32]>, dims = [0]>, dims = [1]>} : vector<32xf32>, memref<32xf32>, vector<32xindex>, vector<32xi1>
     gpu.return
   }
 }

``````````

</details>


https://github.com/llvm/llvm-project/pull/186308


More information about the Mlir-commits mailing list